TensorFlow：使用别人训练好的模型

mac2024-04-21  42

其实想要使用别人训练好的模型很简单,确定模型输入输出张量名,跑一下就可以: import numpy as np import tensorflow as tf import cv2 as cv import os

# COCO label id (as emitted in 'detection_classes') -> (label text, BGR colour).
# Only these classes are drawn; every other detection is ignored.
_DRAWN_CLASSES = {
    1: ('person', (255, 255, 0)),
    3: ('car', (0, 0, 255)),
    6: ('bus', (0, 255, 255)),
    8: ('truck', (255, 0, 255)),
}

# Output tensors of the detection graph, fetched by name in this order.
_OUTPUT_TENSORS = ['num_detections:0', 'detection_scores:0',
                   'detection_boxes:0', 'detection_classes:0']


def _load_graph(model_path):
    """Parse the frozen GraphDef at *model_path* into a fresh ``tf.Graph``."""
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(model_path, 'rb') as f:
        graph_def.ParseFromString(f.read())
    graph = tf.Graph()
    with graph.as_default():
        # name='' keeps the original node names ('image_tensor', ...) instead
        # of prefixing them with 'import/'.
        tf.import_graph_def(graph_def, name='')
    return graph


def _detect(sess, rgb_img, score_threshold):
    """Run the detector on one RGB image.

    Returns a list of ``(class_id, (x1, y1, x2, y2))`` tuples in pixel
    coordinates of *rgb_img*, keeping only detections whose score is
    >= *score_threshold*.
    """
    num, scores, boxes, classes = sess.run(
        _OUTPUT_TENSORS,
        feed_dict={'image_tensor:0': rgb_img[np.newaxis, ...]})  # add batch axis
    height, width = rgb_img.shape[:2]
    n = int(num[0])
    detections = []
    for class_id, score, box in zip(classes[0][:n], scores[0][:n], boxes[0][:n]):
        if float(score) < score_threshold:
            continue
        # Each box is treated as normalized [ymin, xmin, ymax, xmax]:
        # x coordinates scale by the width, y coordinates by the height.
        ymin, xmin, ymax, xmax = (float(v) for v in box)
        detections.append((int(class_id),
                           (xmin * width, ymin * height, xmax * width, ymax * height)))
    return detections


def _draw(img, detections):
    """Draw box + label for each detection whose class is in ``_DRAWN_CLASSES`` onto *img* in place."""
    font = cv.FONT_HERSHEY_SIMPLEX
    for class_id, (x1, y1, x2, y2) in detections:
        style = _DRAWN_CLASSES.get(class_id)
        if style is None:
            continue
        label, color = style
        top_left = (int(x1), int(y1))
        cv.rectangle(img, top_left, (int(x2), int(y2)), color, thickness=2)
        # Keyword thickness: putText's 8th positional argument is lineType,
        # so a trailing positional flag would be misread as the line type.
        cv.putText(img, label, top_left, font, 0.8, color, thickness=2)


def main(folder_path=r'D:\share\samples',
         result_path=r'D:\share\test_result',
         model_path=r'D:\share\ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03\frozen_inference_graph.pb',
         score_threshold=0.5):
    """Run a frozen TF1 detection graph over every .jpg in *folder_path*.

    Each image is padded to 640x640 with ``pad_to_square``, converted
    BGR -> RGB and fed to the graph's ``image_tensor:0`` input. Detections
    with score >= *score_threshold* for person / car / bus / truck are drawn
    on the padded (BGR) image, which is written under the same file name to
    *result_path*.

    Args:
        folder_path: directory containing the input ``.jpg`` images.
        result_path: output directory; created (with parents) if missing.
        model_path: path to a frozen ``.pb`` inference graph.
        score_threshold: detections scoring below this are not drawn.
    """
    os.makedirs(result_path, exist_ok=True)

    graph = _load_graph(model_path)
    with tf.Session(graph=graph) as sess:
        for name in sorted(os.listdir(folder_path)):
            if not name.lower().endswith('.jpg'):
                continue
            img_name = os.path.join(folder_path, name)
            img = cv.imread(img_name)
            if img is None:  # cv.imread returns None rather than raising
                print('skipping unreadable image: ' + img_name)
                continue
            # NOTE(review): pad_to_square is defined elsewhere; assumed to return
            # an HxWx3 BGR image of the requested size — confirm.
            pad_img = pad_to_square(img, [640, 640])
            # BGR -> RGB copy; this (not the BGR pad_img) is what the model sees.
            rgb_img = pad_img[:, :, [2, 1, 0]]
            _draw(pad_img, _detect(sess, rgb_img, score_threshold))
            cv.imwrite(os.path.join(result_path, name), pad_img)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

 

读取 TensorFlow 的 .pb 模型文件（frozen graph），打印每个节点的名称和类型（op），以便确定模型的输入、输出张量名：

import tensorflow as tf

# Print every node of a frozen TF1 graph as "<name> ===> <op>", so the input
# and output tensor names of a model you did not build can be identified.
pb_path = r'D:\share\ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync_2018_07_03\frozen_inference_graph.pb'

gf = tf.GraphDef()
with open(pb_path, 'rb') as f:  # `with` guarantees the file handle is closed
    gf.ParseFromString(f.read())
for n in gf.node:
    print(n.name + ' ===> ' + n.op)
最新回复(0)