PyTorch 的自定义数据集接口确实很方便，这里记录一下自己为分割任务编写的数据读取脚本：
# -*- coding: utf-8 -*-
# @Time    : 2019/10/31 21:36
# @Author  : Yunyun Xu
# @Contact : 1443563995@qq.com
# @File    : MyDatasetReader.py
# @Software: Pycharm
# @Blog    : https://me.csdn.net/xuyunyunaixuexi
import os

# NOTE(review): np and m (scipy.misc) are not used below, and scipy.misc is
# deprecated in recent SciPy releases; consider dropping both imports.
import numpy as np
import scipy.misc as m
from PIL import Image
from torch.utils import data
from torchvision import transforms

from mypath import Path
import custom_transforms as tr


class MyEggSegmentation(data.Dataset):
    """Segmentation dataset that pairs RGB images with label images for one split.

    Images are collected recursively from ``<root>/leftImg8bit/<split>`` and
    labels from ``<root>/gtFine_trainvaltest/gtFine/<split>``; every ``.png``
    file found is used. Image i is paired with label i, so both lists are
    sorted. This assumes the two directory trees use corresponding relative
    paths / filename prefixes (TODO confirm for your data).

    Args:
        args: Options object; the transforms read ``args.base_size`` and
            ``args.crop_size``.
        root: Dataset root directory.
        split: ``"train"``, ``"val"`` or ``"test"``; selects both the
            sub-directories and the transform pipeline.

    Raises:
        Exception: If no images are found, or the image and label counts differ.
    """

    # Normalization statistics shared by all three pipelines (the standard
    # ImageNet mean/std values).
    MEAN = (0.485, 0.456, 0.406)
    STD = (0.229, 0.224, 0.225)

    def __init__(self, args, root=Path.db_root_dir("MyEggs"), split="train"):
        self.root = root
        self.split = split
        self.args = args
        # Keyed by split name; only self.split is ever populated.
        self.image_files = {}
        self.label_files = {}

        self.images_base = os.path.join(self.root, 'leftImg8bit', self.split)
        self.annotations_base = os.path.join(
            self.root, 'gtFine_trainvaltest', 'gtFine', self.split)

        self.image_files[split] = self.recursive_glob(rootdir=self.images_base, suffix=".png")
        self.label_files[split] = self.recursive_glob(rootdir=self.annotations_base, suffix=".png")

        if not self.image_files[split]:
            raise Exception("No files for split=[%s] found in %s" % (split, self.images_base))
        # Images and labels are paired purely by position, so the counts must agree.
        if len(self.image_files[split]) != len(self.label_files[split]):
            raise Exception("Found %d images in %s but %d labels in %s for split=[%s]"
                            % (len(self.image_files[split]), self.images_base,
                               len(self.label_files[split]), self.annotations_base, split))

        print("Found %d %s images" % (len(self.image_files[split]), split))

    def __len__(self):
        """Return the number of image/label pairs in the current split."""
        return len(self.image_files[self.split])

    def __getitem__(self, index):
        """Load pair ``index`` and run it through the split's transform pipeline.

        Raises:
            ValueError: If ``self.split`` is not "train", "val" or "test".
        """
        img_path = self.image_files[self.split][index]
        lbl_path = self.label_files[self.split][index]

        # convert("RGB") guarantees three channels (e.g. drops an alpha channel).
        _img = Image.open(img_path).convert("RGB")
        # The label image is read as-is (no mode conversion), e.g. an index map.
        _target = Image.open(lbl_path)
        # NOTE(review): the "images" key must match what custom_transforms reads;
        # many templates use "image" -- confirm against custom_transforms.py.
        sample = {"images": _img, "label": _target}

        # Only training data gets random augmentation; val/test use their
        # deterministic pipelines.
        if self.split == "train":
            return self.transform_tr(sample)
        if self.split == "val":
            return self.transform_val(sample)
        if self.split == "test":
            return self.transform_ts(sample)
        raise ValueError("Unknown split: %s" % self.split)

    def recursive_glob(self, rootdir='.', suffix=" "):
        """Return the sorted paths of all files under ``rootdir`` ending with ``suffix``.

        os.walk yields entries in filesystem-dependent order. Sorting makes the
        result deterministic, which is what allows images and labels to be
        paired by index.
        """
        return sorted(
            os.path.join(looproot, filename)
            for looproot, _, filenames in os.walk(rootdir)
            for filename in filenames
            if filename.endswith(suffix)
        )

    def transform_tr(self, sample):
        """Training pipeline: random flip, random scale-crop, random blur, normalize, to tensor."""
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=self.MEAN, std=self.STD),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_val(self, sample):
        """Validation pipeline: fixed scale-crop, normalize, to tensor."""
        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=self.MEAN, std=self.STD),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_ts(self, sample):
        """Test pipeline: fixed resize, normalize, to tensor."""
        composed_transforms = transforms.Compose([
            tr.FixedResize(size=self.args.crop_size),
            tr.Normalize(mean=self.MEAN, std=self.STD),
            tr.ToTensor()])
        return composed_transforms(sample)

# Author's note: iterating over this dataset was tested and works, confirming
# that the custom dataset interface (subclassing data.Dataset) is usable.
不过我还有一个问题：分割网络的 crop_size 应该如何根据自己数据集的图像尺寸来确定？还是只能不断地尝试、对比效果？希望得到大家的解答。