pytorch-retinanet训练自己的数据集

mac2026-03-15  4

1.VOC数据集标注

https://blog.csdn.net/qq_38082979/article/details/102868269

2.VOC2csv

得到annotations.csv

classes.csv

val.csv

# -*- coding:utf-8 -*- import csv import os import glob import sys class PascalVOC2CSV(object): def __init__(self,xml=[], ann_path='./annotations.csv',classes_path='./classes.csv'): ''' :param xml: 所有Pascal VOC的xml文件路径组成的列表 :param ann_path: ann_path :param classes_path: classes_path ''' self.xml = xml self.ann_path = ann_path self.classes_path=classes_path self.label=[] self.annotations=[] self.data_transfer() self.write_file() def data_transfer(self): for num, xml_file in enumerate(self.xml): try: # print(xml_file) # 进度输出 sys.stdout.write('\r>> Converting image %d/%d' % ( num + 1, len(self.xml))) sys.stdout.flush() with open(xml_file, 'r') as fp: for p in fp: if '<filename>' in p: self.filen_ame = p.split('>')[1].split('<')[0] if '<object>' in p: # 类别 d = [next(fp).split('>')[1].split('<')[0] for _ in range(9)] self.supercategory = d[0] if self.supercategory not in self.label: self.label.append(self.supercategory) # 边界框 x1 = int(d[-4]); y1 = int(d[-3]); x2 = int(d[-2]); y2 = int(d[-1]) self.annotations.append([os.path.join('/data/VOCdevkit/VOC2007/JPEGImages',self.filen_ame),x1,y1,x2,y2,self.supercategory]) except: continue sys.stdout.write('\n') sys.stdout.flush() def write_file(self,): with open(self.ann_path, 'w', newline='') as fp: csv_writer = csv.writer(fp, dialect='excel') csv_writer.writerows(self.annotations) class_name=sorted(self.label) class_=[] for num,name in enumerate(class_name): class_.append([name,num]) with open(self.classes_path, 'w', newline='') as fp: csv_writer = csv.writer(fp, dialect='excel') csv_writer.writerows(class_) xml_file = glob.glob('./Annotations/*.xml') PascalVOC2CSV(xml_file)

3.pytorch-retinanet

https://github.com/yhenon/pytorch-retinanet

4.pytorch1.1修改nms

https://github.com/huaifeng1993/NMS

model文件修改

def nms(dets, thresh): "Dispatch to either CPU or GPU NMS implementations.\ Accept dets as tensor""" dets = np.array(dets.cpu()) #return pth_nms(dets, thresh) return gpu_nms(dets, thresh)

 

最新回复(0)