TFRecord 使用

mac2022-06-30  85

TFRecord 生成

import os
import xmltodict
import tensorflow as tf
import numpy as np

# Raw strings keep every backslash of the Windows paths literal, so no
# character pair is ever interpreted as an escape sequence.
dir_path = r'F:\数据存储\VOCdevkit\VOC2012\Annotations'
dirs = os.listdir(dir_path)
imgs_dir = r"F:\数据存储\VOCdevkit\VOC2012\JPEGImages"
out_path = r'F:\数据存储\VOCdevkit\voc2012.tfrecord'

# Index in this list is the integer class id stored in the record.
classes = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
    "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]
_class_to_id = {name: idx for idx, name in enumerate(classes)}

sess = tf.Session()

# Build the decode/resize ops ONCE and feed the file path through a
# placeholder. Creating the ops inside the function would add new nodes to
# the graph for every image and make the graph grow without bound.
_img_path = tf.placeholder(tf.string, shape=[])
_decoded_img = tf.image.decode_jpeg(tf.read_file(_img_path))
_decoded_shape = tf.shape(_decoded_img)
_resized_img = tf.image.resize_images(_decoded_img, [224, 224], method=0)


def get_and_resize_img(img_file):
    """Load one JPEG from ``imgs_dir`` and resize it to 224x224.

    Args:
        img_file: File name of the image inside ``imgs_dir``.

    Returns:
        A tuple ``(resized_img_str, w_scale, h_scale, shape_new)``:
        the resized image as raw uint8 bytes, the width and height scale
        factors (new size / original size), and the shape of the resized
        array.
    """
    shape_old, resized_img = sess.run(
        [_decoded_shape, _resized_img],
        feed_dict={_img_path: os.path.join(imgs_dir, img_file)},
    )
    resized_img = np.asarray(resized_img, dtype='uint8')
    shape_new = resized_img.shape
    # Decoded images are laid out (height, width, channels): axis 0 is the
    # height and axis 1 is the width, for both the old and the new shape.
    h_scale = shape_new[0] / int(shape_old[0])
    w_scale = shape_new[1] / int(shape_old[1])
    return resized_img.tobytes(), w_scale, h_scale, shape_new


def _annotated_objects(doc):
    """Return the annotation's objects as a list (possibly empty)."""
    objects = doc['annotation'].get('object', [])
    # xmltodict yields a single mapping for one <object> element and a list
    # for several. isinstance() covers both dict and OrderedDict.
    if isinstance(objects, dict):
        return [objects]
    return objects


try:
    # The context manager closes the writer even if an image fails.
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for file in dirs:
            with open(os.path.join(dir_path, file)) as xml_txt:
                doc = xmltodict.parse(xml_txt.read())

            img_file_name = os.path.splitext(file)[0]
            resized_img_str, w_scale, h_scale, shape = get_and_resize_img(
                img_file_name + '.jpg')

            img_obtain_classes = []
            y_mins = []
            x_mins = []
            y_maxes = []
            x_maxes = []
            for one_object in _annotated_objects(doc):
                class_id = _class_to_id.get(one_object['name'])
                if class_id is None:
                    continue
                box = one_object['bndbox']
                img_obtain_classes.append(class_id)
                # float() parses both integer and fractional coordinate
                # strings; boxes are rescaled to the 224x224 image.
                y_mins.append(h_scale * float(box['ymin']))
                x_mins.append(w_scale * float(box['xmin']))
                y_maxes.append(h_scale * float(box['ymax']))
                x_maxes.append(w_scale * float(box['xmax']))

            example = tf.train.Example(features=tf.train.Features(feature={
                'filename': tf.train.Feature(
                    bytes_list=tf.train.BytesList(
                        value=[img_file_name.encode('utf8')])),
                'shape': tf.train.Feature(
                    int64_list=tf.train.Int64List(
                        value=[shape[0], shape[1], shape[2]])),
                'classes': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=img_obtain_classes)),
                # One entry per object, in the same order as 'classes'.
                'y_mins': tf.train.Feature(
                    float_list=tf.train.FloatList(value=y_mins)),
                'x_mins': tf.train.Feature(
                    float_list=tf.train.FloatList(value=x_mins)),
                'y_maxes': tf.train.Feature(
                    float_list=tf.train.FloatList(value=y_maxes)),
                'x_maxes': tf.train.Feature(
                    float_list=tf.train.FloatList(value=x_maxes)),
                'encoded': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[resized_img_str])),
            }))
            writer.write(example.SerializeToString())
finally:
    sess.close()
print('ok')

TFRecord 读取

import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt

# Class names indexed by the integer id stored in the 'classes' feature.
# "background" occupies index 0 so the ids line up with the list used by
# the generation script, which stores `classes.index(name)`.
classes = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
    "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor",
]


def _parse_record(example_proto):
    """Parse one serialized ``tf.train.Example`` into a feature dict.

    Fixed-length features ('filename', 'shape', 'encoded') come back as
    dense tensors; the per-object features are variable-length and come
    back as sparse tensors.
    """
    features = {
        'filename': tf.FixedLenFeature([], tf.string),
        'shape': tf.FixedLenFeature([3], tf.int64),
        'classes': tf.VarLenFeature(tf.int64),
        'y_mins': tf.VarLenFeature(tf.float32),
        'x_mins': tf.VarLenFeature(tf.float32),
        'y_maxes': tf.VarLenFeature(tf.float32),
        'x_maxes': tf.VarLenFeature(tf.float32),
        'encoded': tf.FixedLenFeature((), tf.string),
    }
    return tf.parse_single_example(example_proto, features=features)


def read_test(input_file, num_records=2):
    """Print and display the first records of a TFRecord file.

    Args:
        input_file: Path of the TFRecord file to read.
        num_records: Maximum number of records to show. Reading stops
            early if the file holds fewer records.
    """
    # Read the TFRecord file through the dataset API.
    dataset = tf.data.TFRecordDataset(input_file).map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    # Create the get_next op once; calling get_next() inside the loop
    # would add a new op to the graph on every iteration.
    next_record = iterator.get_next()

    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(num_records):
            try:
                features = sess.run(next_record)
            except tf.errors.OutOfRangeError:
                break  # fewer records than requested

            name = features['filename'].decode()
            shape = features['shape']
            class_ids = features['classes']
            y_mins = features['y_mins']
            x_mins = features['x_mins']
            y_maxes = features['y_maxes']
            x_maxes = features['x_maxes']
            img_data = features['encoded']

            print(len(img_data))
            print('=======')
            print("shape", shape)
            print("name", name)
            print("classes", class_ids.values)
            print("y_mins", y_mins.values)
            print("x_mins", x_mins.values)
            print("y_maxes", y_maxes.values)
            print("x_maxes", x_maxes.values)

            # Rebuild the image from its raw uint8 bytes using the shape
            # stored next to it in the record.
            img_data = np.frombuffer(img_data, dtype=np.uint8)
            image_data = np.reshape(img_data, shape)
            print("img_data", image_data)

            # Show the image.
            plt.imshow(image_data)
            plt.show()


read_test(r'F:\数据存储\VOCdevkit\voc2012.tfrecord')

尺寸不固定矩阵的存储和读取

import json
import jieba
import tensorflow as tf

with open('../data_save/words_info.txt', 'r', encoding='utf-8') as file:
    dic = json.loads(file.read())
all_words_word2id = dic["all_words_word2id"]

with open('./stop_words.txt', encoding='utf-8') as f:
    # Strip only the newline: slicing off the last character would also
    # chop a real character from a final line that has no newline.
    stop_words = {line.rstrip('\n') for line in f}
print('停用词读取完毕,共{n}个单词'.format(n=len(stop_words)))

# Raw strings keep the Windows backslashes literal.
dir_path = r'F:\数据存储\新闻语料\news2016zh_train.json'
dir_path_test = r'F:\数据存储\新闻语料\news2016zh_valid.json'
out_path = r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord'

# Tokens dropped in addition to the stop words.
_IGNORED_TOKENS = frozenset(['www', 'com', 'http'])


def getCutSequnce(line):
    """Segment ``line`` with jieba and drop stop words and URL fragments.

    Args:
        line: Text to segment.

    Returns:
        List of the remaining tokens, in their original order.
    """
    return [
        word for word in jieba.cut(line, cut_all=False)
        if word not in stop_words and word not in _IGNORED_TOKENS
    ]


def _to_ids(words):
    """Map words to vocabulary ids, skipping out-of-vocabulary words."""
    return [all_words_word2id[w] for w in words if w in all_words_word2id]


# Both files are managed by `with`, so the writer is closed (and its
# buffered records flushed) before the file is read back below.
with open(dir_path, encoding='utf-8') as txt:
    with tf.python_io.TFRecordWriter(out_path) as writer:
        for i, one_dic in enumerate(txt, start=1):
            # Only the first 10000 lines of the corpus are considered.
            if i > 10000:
                break
            if (i % 1000) == 0:
                print(i)

            one_dic_json = json.loads(one_dic)
            title = one_dic_json['title']
            content = one_dic_json['content']
            # Skip overly long articles and entries missing either field.
            if len(content) > 3000:
                continue
            if len(title) == 0 or len(content) == 0:
                continue

            title_list_index = _to_ids(getCutSequnce(title))
            content_list_index = _to_ids(getCutSequnce(content))

            example = tf.train.Example(features=tf.train.Features(feature={
                'title': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=title_list_index)),
                'content': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=content_list_index)),
            }))
            writer.write(example.SerializeToString())


# ---- Reading the variable-length records back ----
import tensorflow as tf
import numpy as np


def _parse_record(example_proto):
    """Parse one serialized example holding variable-length id lists."""
    features = {
        'title': tf.VarLenFeature(tf.int64),
        'content': tf.VarLenFeature(dtype=tf.int64),
    }
    return tf.parse_single_example(example_proto, features=features)


def read_test(input_file, num_records=5):
    """Print the 'content' feature of the first records of a TFRecord file.

    Args:
        input_file: Path of the TFRecord file to read.
        num_records: Maximum number of records to print. Reading stops
            early if the file holds fewer records.
    """
    # Read the TFRecord file through the dataset API.
    dataset = tf.data.TFRecordDataset(input_file).map(_parse_record)
    iterator = dataset.make_initializable_iterator()
    # Create the get_next op once instead of once per loop iteration.
    next_record = iterator.get_next()

    with tf.Session() as sess:
        sess.run(iterator.initializer)
        for _ in range(num_records):
            try:
                features = sess.run(next_record)
            except tf.errors.OutOfRangeError:
                break  # fewer records than requested
            content = features['content']
            print("xx", content)
            # VarLenFeature yields a sparse value; the stored ids are in
            # `.values`, whose shape is the sequence length.
            print("xx", content.values.shape)


read_test(r'F:\数据存储\新闻语料\news2016zh_train_new.tfrecord')

统计数据条数

import tensorflow as tf


def total_sample(file_name):
    """Return the number of records stored in the TFRecord file."""
    # Walk the record iterator and count one per serialized record.
    return sum(1 for _ in tf.python_io.tf_record_iterator(file_name))


result = total_sample('F:\\数据存储\新闻语料\\news2016zh_train_new.tfrecord')
print(result)

转载于:https://www.cnblogs.com/panfengde/p/11302960.html

相关资源:Tensorflow中使用tfrecord方式读取数据的方法
最新回复(0)