pytorch魔改data

mac2024-10-01  55

以项目pytorch-deeplab-xception为例:

测试代码:

https://github.com/jfzhang95/pytorch-deeplab-xception/issues/122

def test(self):
    """Run the model over ``self.test_loader`` (excerpt; the rest of the loop is elided).

    Each batch only exposes 'image' and 'label'. 'image' has presumably already
    been through the dataset's transform pipeline (e.g. Normalize + ToTensor in
    transform_val), so it is not directly viewable.
    """
    self.model.eval()
    self.evaluator.reset()
    # tbar = tqdm(self.test_loader, desc='\r')
    for i, sample in enumerate(self.test_loader):
        image, target = sample['image'], sample['label']
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            output = self.model(image)
        # Bring the network output and the label back to host numpy arrays.
        pred = output.data.cpu().numpy()
        target = target.cpu().numpy()

这里用 enumerate 遍历 test_loader,每个 batch 只能取到 image 和 target(label)两项,

但 image 已经过数据增强(可能被缩放、裁剪),又经过了 ToTensor() 转换和均值/标准差归一化(Normalize),

已经无法直接满足后续可视化任务的需要。

因此还需要额外返回图像名 name(path)和 ori_image(resize 后、未归一化的原始图像,用于可视化)。下面的测试代码还额外统计了每个 batch 的推理耗时:

def test(self):
    """Run the model over ``self.test_loader`` and print per-batch latency.

    Excerpt: for each batch ``pred`` (arg-max over axis 1, presumably the class
    axis of an N x C x H x W output) and ``target`` are computed; whatever
    consumes them is not shown here.
    """
    self.model.eval()
    self.evaluator.reset()
    # tbar = tqdm(self.test_loader, desc='\r')
    num = len(self.test_loader)

    # CUDA kernels run asynchronously: wait for the device before *each*
    # timestamp, otherwise the measured time only covers the kernel launch.
    # Guarded so the loop also runs on CPU-only machines.
    if torch.cuda.is_available():
        sync = torch.cuda.synchronize
    else:
        def sync():
            pass

    for i, sample in enumerate(self.test_loader):
        image, target = sample['image'], sample['label']
        # NOTE(review): the batch is not moved to the model's device here;
        # confirm model and batch live on the same device.
        print(i, "/", num)
        sync()
        start = time.perf_counter()
        with torch.no_grad():
            output = self.model(image)
        sync()
        end = time.perf_counter()
        times = (end - start) * 1000
        print(times, "ms")
        pred = output.data.cpu().numpy()
        pred = np.argmax(pred, axis=1)
        target = target.cpu().numpy()

原始的 dataset 构建代码:

# Build the COCO segmentation datasets for the training and validation splits.
train_set = coco.COCOSegmentation(args, split='train')
val_set = coco.COCOSegmentation(args, split='val')

# Only the training split is shuffled; validation keeps a fixed order.
train_loader = DataLoader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    **kwargs,
)
val_loader = DataLoader(
    val_set,
    batch_size=args.batch_size,
    shuffle=False,
    **kwargs,
)

from dataloaders.datasets import cityscapes, coco, combine_dbs, pascal, sbd
from torch.utils.data import DataLoader


def make_data_loader(args, **kwargs):
    """Build the train/val/test DataLoaders for ``args.dataset``.

    Only the COCO branch is reproduced in this excerpt; the branches for the
    other imported datasets are presumably omitted.

    Args:
        args: options object; reads ``dataset`` and ``batch_size``.
        **kwargs: forwarded unchanged to every ``DataLoader``.

    Returns:
        (train_loader, val_loader, test_loader, num_class); ``test_loader`` is
        None for COCO.

    Raises:
        NotImplementedError: for a dataset not handled below.
    """
    if args.dataset == 'coco':
        train_set = coco.COCOSegmentation(args, split='train')
        val_set = coco.COCOSegmentation(args, split='val')
        num_class = train_set.NUM_CLASSES
        # Only the training split is shuffled.
        train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs)
        val_loader = DataLoader(val_set, batch_size=args.batch_size, shuffle=False, **kwargs)
        test_loader = None
        return train_loader, val_loader, test_loader, num_class
    else:
        raise NotImplementedError("Unsupported dataset: {!r}".format(args.dataset))

COCO 的 Dataset 由 dataloaders/datasets/coco.py 中的 COCOSegmentation 类实现:

import numpy as np
import torch
from torch.utils.data import Dataset
from mypath import Path
from tqdm import trange
import os
from pycocotools.coco import COCO
from pycocotools import mask
from torchvision import transforms
from dataloaders import custom_transforms as tr
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


class COCOSegmentation(Dataset):
    """COCO instance annotations rasterised into semantic-segmentation masks.

    Only categories listed in ``CAT_LIST`` are kept. Each pixel's label is the
    position of its category in that list; 0 is the mask's fill value
    (background).
    """

    NUM_CLASSES = 21
    CAT_LIST = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    def __init__(self,
                 args,
                 base_dir=Path.db_root_dir('coco'),
                 split='train',
                 year='2017'):
        super().__init__()
        ann_file = os.path.join(base_dir, 'annotations/instances_{}{}.json'.format(split, year))
        # Cache of the image ids that pass the _preprocess filter.
        ids_file = os.path.join(base_dir, 'annotations/{}_ids_{}.pth'.format(split, year))
        self.img_dir = os.path.join(base_dir, 'images/{}{}'.format(split, year))
        self.split = split
        self.coco = COCO(ann_file)
        self.coco_mask = mask
        if os.path.exists(ids_file):
            self.ids = torch.load(ids_file)
        else:
            ids = list(self.coco.imgs.keys())
            self.ids = self._preprocess(ids, ids_file)
        self.args = args

    def __getitem__(self, index):
        """Load sample *index* and apply the split's transform pipeline.

        Raises:
            ValueError: if ``self.split`` is neither 'train' nor 'val'
                (previously this silently returned None).
        """
        _img, _target = self._make_img_gt_point_pair(index)
        sample = {'image': _img, 'label': _target}
        if self.split == "train":
            return self.transform_tr(sample)
        elif self.split == 'val':
            return self.transform_val(sample)
        raise ValueError("Unsupported split: {!r}".format(self.split))

    def _make_img_gt_point_pair(self, index):
        """Return (RGB PIL image, PIL label mask) for dataset position *index*."""
        coco = self.coco
        img_id = self.ids[index]
        img_metadata = coco.loadImgs(img_id)[0]
        path = img_metadata['file_name']
        # Close the file deterministically; convert() yields an independent image.
        with Image.open(os.path.join(self.img_dir, path)) as img_file:
            _img = img_file.convert('RGB')
        cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        _target = Image.fromarray(self._gen_seg_mask(
            cocotarget, img_metadata['height'], img_metadata['width']))
        return _img, _target

    def _preprocess(self, ids, ids_file):
        """Keep ids whose mask has > 1000 labelled pixels; cache them to *ids_file*."""
        print("Preprocessing mask, this will take a while. "
              "But don't worry, it only run once for each split.")
        tbar = trange(len(ids))
        new_ids = []
        for i in tbar:
            img_id = ids[i]
            cocotarget = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
            img_metadata = self.coco.loadImgs(img_id)[0]
            # Named seg_mask so it does not shadow the pycocotools ``mask`` import.
            seg_mask = self._gen_seg_mask(cocotarget, img_metadata['height'],
                                          img_metadata['width'])
            # more than 1k pixels
            if (seg_mask > 0).sum() > 1000:
                new_ids.append(img_id)
            tbar.set_description('Doing: {}/{}, got {} qualified images'.
                                 format(i, len(ids), len(new_ids)))
        print('Found number of qualified images: ', len(new_ids))
        torch.save(new_ids, ids_file)
        return new_ids

    def _gen_seg_mask(self, target, h, w):
        """Rasterise the annotations in *target* into an (h, w) uint8 label mask."""
        label_mask = np.zeros((h, w), dtype=np.uint8)
        coco_mask = self.coco_mask
        for instance in target:
            cat = instance['category_id']
            # Filter first: decoding RLE for an ignored category is wasted work.
            if cat not in self.CAT_LIST:
                continue
            c = self.CAT_LIST.index(cat)
            rle = coco_mask.frPyObjects(instance['segmentation'], h, w)
            m = coco_mask.decode(rle)
            # Only still-unlabelled pixels (== 0) are written, so earlier
            # instances win where instances overlap.
            if len(m.shape) < 3:
                label_mask[:, :] += (label_mask == 0) * (m * c)
            else:
                # Several parts: a pixel belongs to the instance if any part covers it.
                label_mask[:, :] += (label_mask == 0) * (
                    ((np.sum(m, axis=2)) > 0) * c).astype(np.uint8)
        return label_mask

    def transform_tr(self, sample):
        """Training augmentation: flip, random scale-crop, blur, normalise, to tensor."""
        composed_transforms = transforms.Compose([
            tr.RandomHorizontalFlip(),
            tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size),
            tr.RandomGaussianBlur(),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

    def transform_val(self, sample):
        """Validation pipeline: fixed centre crop, normalise, to tensor."""
        composed_transforms = transforms.Compose([
            tr.FixScaleCrop(crop_size=self.args.crop_size),
            tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
            tr.ToTensor()])
        return composed_transforms(sample)

    def __len__(self):
        return len(self.ids)

这里的数据增强是项目自己实现的类(dataloaders/custom_transforms.py),而不是 torchvision 中的 transforms,

只是借用了 torchvision 的组合容器 transforms.Compose() 把它们串联起来。

custom_transforms.py 中的数据增强类有:

class Normalize(object):

class ToTensor(object):

class RandomHorizontalFlip(object):

class RandomRotate(object):

class RandomGaussianBlur(object):

class RandomScaleCrop(object):

class FixScaleCrop(object):

class FixedResize(object):

数据从构建到增强的调用流程如下:

pytorch-deeplab-xception/train.py

# train.py: the trainer obtains every loader, plus the class count, from make_data_loader.
self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs) 

pytorch-deeplab-xception/dataloaders/__init__.py

# dataloaders/__init__.py: make_data_loader builds one Dataset per split (Cityscapes branch shown).
train_set = cityscapes.CityscapesSegmentation(args, split='train') 

pytorch-deeplab-xception/dataloaders/datasets/coco.py

def __getitem__(self, index):

def __getitem__(self, index):
    """Load sample *index* and apply the split's transform pipeline.

    Raises:
        ValueError: if ``self.split`` is neither 'train' nor 'val'; failing
            here is clearer than handing None to the DataLoader.
    """
    _img, _target = self._make_img_gt_point_pair(index)
    sample = {'image': _img, 'label': _target}
    if self.split == "train":
        return self.transform_tr(sample)
    elif self.split == 'val':
        return self.transform_val(sample)
    raise ValueError("Unsupported split: {!r}".format(self.split))

 def _make_img_gt_point_pair(self, index):

def _make_img_gt_point_pair(self, index):
    """Return (RGB PIL image, PIL label mask) for dataset position *index*."""
    coco = self.coco
    img_id = self.ids[index]
    img_metadata = coco.loadImgs(img_id)[0]
    path = img_metadata['file_name']
    # Close the file deterministically; convert() yields an independent image.
    with Image.open(os.path.join(self.img_dir, path)) as img_file:
        _img = img_file.convert('RGB')
    cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
    _target = Image.fromarray(self._gen_seg_mask(
        cocotarget, img_metadata['height'], img_metadata['width']))
    return _img, _target

def transform_tr(self, sample): 

def transform_tr(self, sample):
    """Apply the training-time augmentation pipeline to *sample*.

    Steps: random horizontal flip, random scale + crop to ``args.crop_size``,
    random Gaussian blur, fixed per-channel mean/std normalisation, then
    conversion to tensors.
    """
    steps = [
        tr.RandomHorizontalFlip(),
        tr.RandomScaleCrop(base_size=self.args.base_size,
                           crop_size=self.args.crop_size),
        tr.RandomGaussianBlur(),
        tr.Normalize(mean=(0.485, 0.456, 0.406),
                     std=(0.229, 0.224, 0.225)),
        tr.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    return pipeline(sample)

 

pytorch-deeplab-xception/dataloaders/custom_transforms.py

class Normalize(object):

class ToTensor(object):

class FixScaleCrop(object):

class FixScaleCrop(object):
    """Resize so the short side equals ``crop_size``, then take the centre square.

    The image is resampled bilinearly and the label with nearest neighbour, so
    label values are never interpolated.
    """

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        side = self.crop_size
        width, height = image.size
        if width > height:
            new_h = side
            new_w = int(1.0 * width * new_h / height)
        else:
            new_w = side
            new_h = int(1.0 * height * new_w / width)
        new_size = (new_w, new_h)
        image = image.resize(new_size, Image.BILINEAR)
        label = label.resize(new_size, Image.NEAREST)
        # Centre crop of side x side.
        width, height = image.size
        left = int(round((width - side) / 2.))
        top = int(round((height - side) / 2.))
        box = (left, top, left + side, top + side)
        return {'image': image.crop(box), 'label': label.crop(box)}


class Normalize(object):
    """Divide pixel values by 255, then subtract ``mean`` and divide by ``std`` per channel.

    Works on the PIL image / ndarray stored under 'image' (it runs before
    ToTensor); the label is only cast to float32.

    Args:
        mean (tuple): per-channel means subtracted after scaling.
        std (tuple): per-channel standard deviations divided out.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        pixels = np.array(sample['image']).astype(np.float32)
        labels = np.array(sample['label']).astype(np.float32)
        # Augmented assignment keeps the float32 dtype of ``pixels``.
        pixels /= 255.0
        pixels -= self.mean
        pixels /= self.std
        return {'image': pixels, 'label': labels}


class ToTensor(object):
    """Turn the numpy image/label of a sample into float torch tensors.

    The image goes from H x W x C (numpy layout) to C x H x W (torch layout);
    the label keeps its shape.
    """

    def __call__(self, sample):
        chw = np.array(sample['image']).astype(np.float32).transpose((2, 0, 1))
        labels = np.array(sample['label']).astype(np.float32)
        return {
            'image': torch.from_numpy(chw).float(),
            'label': torch.from_numpy(labels).float(),
        }

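改进后的 custom_transforms.py,以及 coco.py 中的 __getitem__ / _make_img_gt_point_pair,可以让每个 sample 额外返回原图(ori_image)和路径(path):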

import os
import random

import numpy as np
import torch
from PIL import Image, ImageOps, ImageFilter


def _replace_image_label(sample, img, mask):
    """Return a shallow copy of *sample* with new 'image' and 'label' values.

    All other keys (e.g. 'ori_image', 'path') are carried over unchanged, so
    every transform accepts both plain {'image', 'label'} samples and samples
    extended with visualisation metadata.
    """
    out = dict(sample)
    out['image'] = img
    out['label'] = mask
    return out


class Normalize(object):
    """Divide pixel values by 255, then subtract ``mean`` and divide by ``std`` per channel.

    Works on the PIL image / ndarray under 'image' (i.e. before ToTensor); the
    label is only cast to float32. 'ori_image' is left untouched.

    Args:
        mean (tuple): means for each channel.
        std (tuple): standard deviations for each channel.
    """

    def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
        self.mean = mean
        self.std = std

    def __call__(self, sample):
        img = np.array(sample['image']).astype(np.float32)
        mask = np.array(sample['label']).astype(np.float32)
        # In-place ops keep float32; ``img - self.mean`` with a tuple operand
        # would promote the result to float64.
        img /= 255.0
        img -= self.mean
        img /= self.std
        return _replace_image_label(sample, img, mask)


class ToTensor(object):
    """Convert the arrays in a sample to float tensors.

    'image' (and 'ori_image', when present) go from H x W x C (numpy layout)
    to C x H x W (torch layout); 'label' keeps its shape. 'ori_image' was never
    normalised, so it can be displayed directly.
    """

    @staticmethod
    def _hwc_to_chw_tensor(img):
        arr = np.array(img).astype(np.float32).transpose((2, 0, 1))
        return torch.from_numpy(arr).float()

    def __call__(self, sample):
        img = self._hwc_to_chw_tensor(sample['image'])
        mask = torch.from_numpy(np.array(sample['label']).astype(np.float32)).float()
        out = _replace_image_label(sample, img, mask)
        if 'ori_image' in sample:
            out['ori_image'] = self._hwc_to_chw_tensor(sample['ori_image'])
        return out


class RandomHorizontalFlip(object):
    """Mirror image and label left-right with probability 0.5.

    'ori_image' passes through unflipped; the crop transforms below overwrite
    it with the augmented image.
    """

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        return _replace_image_label(sample, img, mask)


class RandomRotate(object):
    """Rotate image (bilinear) and label (nearest) by an angle drawn from [-degree, degree]."""

    def __init__(self, degree):
        self.degree = degree

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        rotate_degree = random.uniform(-1 * self.degree, self.degree)
        img = img.rotate(rotate_degree, Image.BILINEAR)
        mask = mask.rotate(rotate_degree, Image.NEAREST)
        # Carry 'ori_image'/'path' through instead of dropping them.
        return _replace_image_label(sample, img, mask)


class RandomGaussianBlur(object):
    """Blur the image (not the label) with probability 0.5, radius drawn from [0, 1)."""

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(radius=random.random()))
        return _replace_image_label(sample, img, mask)


class RandomScaleCrop(object):
    """Randomly rescale the short side, pad if needed, then take a random square crop.

    The short side is drawn (inclusive) from [0.5 * base_size, 2 * base_size].
    Image padding uses 0, label padding uses ``fill``.
    """

    def __init__(self, base_size, crop_size, fill=0):
        self.base_size = base_size
        self.crop_size = crop_size
        self.fill = fill

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop
        if short_size < self.crop_size:
            padh = self.crop_size - oh if oh < self.crop_size else 0
            padw = self.crop_size - ow if ow < self.crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - self.crop_size)
        y1 = random.randint(0, h - self.crop_size)
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        out = _replace_image_label(sample, img, mask)
        if 'ori_image' in sample:
            # Augmented but not normalised, and pixel-aligned with 'label'.
            out['ori_image'] = img
        return out


class FixScaleCrop(object):
    """Resize so the short side equals ``crop_size``, then take the centre square."""

    def __init__(self, crop_size):
        self.crop_size = crop_size

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        w, h = img.size
        if w > h:
            oh = self.crop_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = self.crop_size
            oh = int(1.0 * h * ow / w)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - self.crop_size) / 2.))
        y1 = int(round((h - self.crop_size) / 2.))
        img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
        out = _replace_image_label(sample, img, mask)
        if 'ori_image' in sample:
            # Resized/cropped but not normalised, and pixel-aligned with 'label'.
            out['ori_image'] = img
        return out


class FixedResize(object):
    """Resize image (bilinear) and label (nearest) to a fixed square size."""

    def __init__(self, size):
        self.size = (size, size)  # square target; PIL takes (width, height)

    def __call__(self, sample):
        img = sample['image']
        mask = sample['label']
        assert img.size == mask.size
        img = img.resize(self.size, Image.BILINEAR)
        mask = mask.resize(self.size, Image.NEAREST)
        out = _replace_image_label(sample, img, mask)
        if 'ori_image' in sample:
            # Like the crop transforms, keep 'ori_image' the same size as 'label'.
            out['ori_image'] = img
        return out


# --- dataloaders/datasets/coco.py: COCOSegmentation methods (replace the originals) ---

def __getitem__(self, index):
    """Return transformed sample *index* with visualisation extras.

    Besides 'image' and 'label', the sample carries 'ori_image' (replaced by
    the crop transforms with the resized, un-normalised image) and 'path'
    (the file name without its extension). The 'test' split reuses
    transform_val.

    Raises:
        ValueError: for a split other than 'train', 'val' or 'test'.
    """
    _img, _target, _path = self._make_img_gt_point_pair(index)
    sample = {'image': _img, 'label': _target, 'ori_image': _img, 'path': _path}
    if self.split == 'train':
        return self.transform_tr(sample)
    if self.split in ('val', 'test'):
        return self.transform_val(sample)
    raise ValueError("Unsupported split: {!r}".format(self.split))


def _make_img_gt_point_pair(self, index):
    """Load image *index* and its rasterised label mask.

    Returns:
        (RGB PIL image, PIL label mask, file name without extension)
    """
    coco = self.coco
    img_id = self.ids[index]
    img_metadata = coco.loadImgs(img_id)[0]
    path = img_metadata['file_name']
    # splitext strips any extension, unlike split('.jpg') which only handled .jpg.
    _path = os.path.splitext(path)[0]
    # Close the file deterministically; convert() yields an independent image.
    with Image.open(os.path.join(self.img_dir, path)) as img_file:
        _img = img_file.convert('RGB')
    cocotarget = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
    _target = Image.fromarray(self._gen_seg_mask(
        cocotarget, img_metadata['height'], img_metadata['width']))
    return _img, _target, _path

 

最新回复(0)