其实想要使用别人训练好的模型很简单,确定模型输入输出张量名,跑一下就可以: import numpy as np import tensorflow as tf import cv2 as cv import os
def main():
    """Run a frozen SSD detection graph over a folder of images and save annotated copies.

    Every ``.jpg`` file in the sample folder is read with OpenCV, passed
    through ``pad_to_square`` with a 640x640 target, fed to the graph's
    ``image_tensor:0`` input, and written to the result folder under the same
    file name with boxes and labels drawn for the classes listed in
    ``box_colors``. Detections scoring below ``score_threshold`` are ignored.
    """
    folder_path = r'D:\share\samples'
    result_path = r'D:\share\test_result'
    # makedirs(exist_ok=True) also creates missing parent directories and does
    # not fail if the directory appears between a check and the creation.
    os.makedirs(result_path, exist_ok=True)

    font = cv.FONT_HERSHEY_SIMPLEX
    class_name = ["person", "bicycle", "car", "motorbike", "aeroplane", "bus", "train", "truck"]
    # Class id -> drawing colour. The colours are applied to the image that
    # came from cv.imread, so they are in that image's channel order.
    # Detections of any other class are not drawn.
    box_colors = {
        1: (255, 255, 0),
        3: (0, 0, 255),
        6: (0, 255, 255),
        8: (255, 0, 255),
    }
    score_threshold = 0.5  # detections scoring below this are not drawn
    model_path = r'D:\share\ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03\frozen_inference_graph.pb'

    # Read the serialized graph.
    with tf.gfile.FastGFile(model_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    with tf.Session() as sess:
        # The original called sess.graph.as_default() without a `with`, which
        # has no effect; the graph is imported into the default graph, which
        # is the one this session runs.
        tf.import_graph_def(graph_def, name='')

        # Look the output tensors up once instead of once per image.
        fetches = [
            sess.graph.get_tensor_by_name('num_detections:0'),
            sess.graph.get_tensor_by_name('detection_scores:0'),
            sess.graph.get_tensor_by_name('detection_boxes:0'),
            sess.graph.get_tensor_by_name('detection_classes:0'),
        ]

        for sub in os.listdir(folder_path):
            if not sub.endswith('.jpg'):
                continue
            img_name = os.path.join(folder_path, sub)
            result_name = os.path.join(result_path, sub)

            img = cv.imread(img_name)
            if img is None:
                # cv.imread signals an unreadable file by returning None.
                print('skip unreadable image: ' + img_name)
                continue

            # NOTE(review): pad_to_square is not defined in this snippet;
            # presumably it returns a 640x640 3-channel image - confirm.
            pad_img = pad_to_square(img, [640, 640])
            change_img = pad_img[:, :, [2, 1, 0]]  # BGR2RGB
            height, width = pad_img.shape[:2]

            # Run the model on the RGB copy. The original fed pad_img here,
            # so the channel swap above was never actually used.
            out = sess.run(
                fetches,
                feed_dict={'image_tensor:0': change_img.reshape(1, height, width, 3)},
            )
            num_detections, scores, boxes, classes = out

            # Visualize detected bounding boxes on the OpenCV image.
            for i in range(int(num_detections[0])):
                score = float(scores[0][i])
                if score < score_threshold:
                    continue
                class_id = int(classes[0][i])
                color = box_colors.get(class_id)
                if color is None:
                    continue

                # Indices 1 and 3 are used as x, 0 and 2 as y, as in the
                # original. x is scaled by the width and y by the height; the
                # original had the two swapped, which only worked because the
                # padded image is square.
                bbox = [float(v) for v in boxes[0][i]]
                p1 = (int(bbox[1] * width), int(bbox[0] * height))
                p2 = (int(bbox[3] * width), int(bbox[2] * height))

                cv.rectangle(pad_img, p1, p2, color, thickness=2)
                # Class ids are 1-based while class_name is 0-based.
                cv.putText(pad_img, class_name[class_id - 1], p1, font, 0.8, color, 2, False)

            cv.imwrite(result_name, pad_img)


if __name__ == '__main__':
    main()
读取 TensorFlow 的 .pb 模型文件,打印每个节点的名称和操作类型,以便确定模型的输入、输出张量名:
import tensorflow as tf

# Print every node of a frozen graph as "<name> ===> <op>" so that the input
# and output tensor names can be identified.
model_path = r'D:\share\ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03\frozen_inference_graph.pb'

gf = tf.GraphDef()
# Use a context manager so the file handle is closed after reading; the
# original left the handle returned by open() unclosed.
with open(model_path, 'rb') as f:
    gf.ParseFromString(f.read())

for n in gf.node:
    print(n.name + ' ===> ' + n.op)