pytorch-将保存的模型进行新的图片分类

mac2024-03-21  31

from PIL import Image import torchvision.transforms as T from torch.autograd import Variable as V import torch as t model = model.cuda()#导入网络模型 model.eval() model.load_state_dict(t.load('./models/new/model.dat'))#加载训练好的模型文件 import numpy as np def result_(res): if res==0: return 'airplane' elif res==1: return 'ship' elif res==2: return 'bridge' elif res==3: return 'oilcan' elif res==4: return 'build' else: return 'Nan' trans=T.Compose([ T.Scale(325), T.CenterCrop(299), T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) trans_gray=T.Compose([ T.Scale(325), T.CenterCrop(299), T.ToTensor(), T.Normalize((0.1307,), (0.3081,)) ]) import csv import cv2 #stag_01_submit # csvFile = open("test_submit.csv", "w") csvFile = open("stag_01_submit.csv", "w") #创建csv文件 writer = csv.writer(csvFile) #创建写的对象 #先写入columns_name writer.writerow(["id","label","ship","bridge","airplane","build","oilcan"]) #写入列的名称 #读入图片 # test_root = './test/' test_root = './stag_01/' img_test=os.listdir(test_root) for i in range(len(img_test)): rd_img = cv2.imread(test_root+img_test[i]) img = Image.open(test_root+img_test[i]) img = img.convert('RGB') print(test_root+img_test[i]) print(rd_img.shape) input=trans(img) input=input.unsqueeze(0)#这里经过转换后输出的input格式是[C,H,W],网络输入还需要增加一维批量大小B #增加一维,输出的img格式为[1,C,H,W] input = V(input.cuda()) score = model(input)#将图片输入网络得到输出 probability = t.nn.functional.softmax(score,dim=1)#计算softmax,即该图片属于各类的概率 max_value,index = t.max(probability,1)#找到最大概率对应的索引号,该图片即为该索引号对应的类别 class_index = result_(index) # print(class_index) probability=np.round(probability.cpu().detach().numpy(),3) # print(probability[0][0]) writer.writerow([img_test[i],class_index,probability[0][1],probability[0][2],probability[0][0],probability[0][4],probability[0][3]]) csvFile.close() import pandas as pd data = pd.read_csv('test_submit.csv')

 

 

 

 

 

 

最新回复(0)