基于pytorch计算ssim和ms-ssim

mac2026-06-20  0

使用 PyTorch 计算两组图片之间的 SSIM 和 MS-SSIM。

首先是 SSIM 和 MS-SSIM 的实现,包括计算函数和封装好的类(ssim.py):
import torch
import torch.nn.functional as F


def _fspecial_gauss_1d(size, sigma):
    """Return a normalised 1-D Gaussian kernel of shape (1, 1, size).

    Args:
        size: number of taps; the public functions check that it is odd so
            the kernel is centred.
        sigma: standard deviation of the Gaussian, in pixels.
    """
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size // 2  # centre the taps on zero
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()  # weights sum to 1, so filtering preserves the local mean
    return g.unsqueeze(0).unsqueeze(0)


def _gauss_window(win_size, win_sigma, channels):
    """Return the 1-D Gaussian repeated once per channel: shape (C, 1, 1, win_size)."""
    return _fspecial_gauss_1d(win_size, win_sigma).repeat(channels, 1, 1, 1)


def _check_inputs(X, Y, win_size):
    """Raise ValueError unless X and Y are 4-d tensors of the same tensor type
    (dtype and device kind, as reported by ``Tensor.type()``) and shape, and
    ``win_size`` is odd."""
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')
    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')
    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')
    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')


def gaussian_filter(input, win):
    """Blur every channel of ``input`` with a separable Gaussian window.

    Args:
        input: tensor of shape (N, C, H, W). (The name shadows the builtin but
            is kept so keyword callers keep working.)
        win: per-channel 1-D kernel of shape (C, 1, 1, k).

    Returns:
        Tensor of shape (N, C, H - k + 1, W - k + 1): no padding is applied,
        so H and W must be at least k.
    """
    _, channels, _, _ = input.shape  # unpacking also enforces a 4-d input
    out = F.conv2d(input, win, stride=1, padding=0, groups=channels)  # along W
    return F.conv2d(out, win.transpose(2, 3), stride=1, padding=0,
                    groups=channels)  # along H


def _ssim(X, Y, win, data_range=1023, size_average=True, full=False):
    """Core SSIM computation shared by ``ssim`` and ``ms_ssim``.

    Returns the mean SSIM (and, when ``full``, the mean contrast-structure
    term ``cs``), either as scalars (``size_average=True``) or one value per
    batch element, shape (N,), averaged over C, H and W.
    """
    K1 = 0.01
    K2 = 0.03
    compensation = 1.0

    # stabilising constants from the SSIM definition
    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # local (co)variances via E[XY] - E[X]E[Y]
    sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
    sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
    sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    if size_average:
        ssim_val = ssim_map.mean()
        cs = cs_map.mean()
    else:
        # reduce along W, H, C -> one value per batch element
        ssim_val = ssim_map.mean(-1).mean(-1).mean(-1)
        cs = cs_map.mean(-1).mean(-1).mean(-1)

    if full:
        return ssim_val, cs
    return ssim_val


def ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1,
         size_average=True, full=False):
    """Compute SSIM between two batches of images.

    Args:
        X, Y: tensors of shape (N, C, H, W) with the same type and shape;
            H and W must be at least the window length (no padding is used).
        win_size: odd window length; used to build the window when ``win``
            is None (it is validated as odd either way).
        win_sigma: Gaussian standard deviation. NOTE: the default of 10
            differs from the 1.5 used by the SSIM/MS_SSIM classes below (and
            by the standard SSIM definition); with an 11-tap window it is
            close to a box filter. Pass ``win_sigma=1.5`` for the usual metric.
        win: optional precomputed window of shape (C, 1, 1, k).
        data_range: dynamic range of pixel values (1 for images in [0, 1]).
        size_average: if True return the mean over the batch, otherwise a
            tensor of shape (N,).
        full: if True also return the contrast-structure term ``cs``.

    Returns:
        ``ssim_val``, or ``(ssim_val, cs)`` when ``full`` is True.

    Raises:
        ValueError: on mismatched / non-4-d inputs or an even ``win_size``.
    """
    _check_inputs(X, Y, win_size)

    if win is None:
        win = _gauss_window(win_size, win_sigma, X.shape[1])

    ssim_val, cs = _ssim(X, Y, win=win, data_range=data_range,
                         size_average=False, full=True)
    if size_average:
        ssim_val = ssim_val.mean()
        cs = cs.mean()

    if full:
        return ssim_val, cs
    return ssim_val


def ms_ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1,
            size_average=True, full=False, weights=None):
    """Compute multi-scale SSIM (MS-SSIM) between two batches of images.

    MS-SSIM = prod_{j<M} cs_j ** w_j * ssim_M ** w_M: the contrast-structure
    term at every scale except the coarsest, and full SSIM only at the
    coarsest scale M. Each further scale halves H and W with 2x2 average
    pooling, so the coarsest scale must still be at least the window length
    (about 161 px per side for the default 5 weights and 11-tap window).

    Args:
        X, Y: tensors of shape (N, C, H, W) with the same type and shape.
        win_size, win_sigma, win, data_range: as in :func:`ssim` (including
            the note about the non-standard ``win_sigma`` default).
        size_average: if True return the batch mean, else shape (N,).
        full: accepted for signature symmetry with :func:`ssim`; ignored.
        weights: per-scale exponents (tensor or sequence); its length sets
            the number of scales. Defaults to the 5 standard MS-SSIM weights.

    Returns:
        The MS-SSIM value(s). Per-scale terms are clamped at 0 before the
        fractional powers, so strongly anti-correlated inputs give 0 instead
        of NaN.

    Raises:
        ValueError: on invalid inputs or an empty / non-1-d ``weights``.
    """
    _check_inputs(X, Y, win_size)

    if weights is None:
        weights = torch.FloatTensor([0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    # accepts lists too, and moves CPU weights next to CUDA inputs
    weights = torch.as_tensor(weights, dtype=X.dtype, device=X.device)
    if weights.dim() != 1 or weights.shape[0] == 0:
        raise ValueError('weights must be a non-empty 1-d sequence.')

    if win is None:
        win = _gauss_window(win_size, win_sigma, X.shape[1])

    levels = weights.shape[0]
    mcs = []
    for level in range(levels):
        ssim_val, cs = _ssim(X, Y, win=win, data_range=data_range,
                             size_average=False, full=True)
        # negative terms raised to fractional weights would produce NaN
        mcs.append(torch.relu(cs))
        if level < levels - 1:  # nothing to downsample after the last scale
            # odd sizes get one zero-padded row/column so none is dropped
            padding = (X.shape[2] % 2, X.shape[3] % 2)
            X = F.avg_pool2d(X, kernel_size=2, padding=padding)
            Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)

    # cs from scales 0..M-2, full SSIM from the coarsest scale only
    factors = torch.stack(mcs[:-1] + [torch.relu(ssim_val)], dim=0)  # (levels, batch)
    msssim_val = torch.prod(factors ** weights.unsqueeze(1), dim=0)  # (batch,)

    if size_average:
        msssim_val = msssim_val.mean()
    return msssim_val


# Classes to re-use window
class SSIM(torch.nn.Module):
    """Module wrapper around :func:`ssim` that builds the Gaussian window once.

    ``channel`` should equal the C dimension of the inputs, because the
    window is repeated once per channel for the grouped convolution.
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=255,
                 size_average=True, channel=3):
        super(SSIM, self).__init__()
        self.win = _gauss_window(win_size, win_sigma, channel)
        self.size_average = size_average
        self.data_range = data_range

    def forward(self, X, Y):
        """Return ``ssim(X, Y)`` using the cached window."""
        return ssim(X, Y, win=self.win, data_range=self.data_range,
                    size_average=self.size_average)


class MS_SSIM(torch.nn.Module):
    """Module wrapper around :func:`ms_ssim` that builds the window once.

    ``weights`` may be a tensor or a plain sequence; None selects the
    default 5-scale weights.
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=255,
                 size_average=True, channel=3, weights=None):
        super(MS_SSIM, self).__init__()
        self.win = _gauss_window(win_size, win_sigma, channel)
        self.size_average = size_average
        self.data_range = data_range
        self.weights = weights

    def forward(self, X, Y):
        """Return ``ms_ssim(X, Y)`` using the cached window and weights."""
        return ms_ssim(X, Y, win=self.win, size_average=self.size_average,
                       data_range=self.data_range, weights=self.weights)

上面的 SSIM / MS_SSIM 类我在 PyTorch 中当作损失函数使用。SSIM 越大表示两张图越相似,所以作为损失时一般取 1 - SSIM(或 1 - MS-SSIM)。

使用

脚本中还用到了几个辅助函数,我在下面给出。

import argparse

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from util import (map_range, cv2torch, random_tone_map, DirectoryDataset,
                  str2bool)
# ssim.py above was placed in a local package folder named pytorch_msssim,
# hence this import path.
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM


def parse_args():
    """Parse the command-line options of the SSIM / MS-SSIM evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batch_size', type=int, default=1, help='Batch size.')
    parser.add_argument(
        '-d', '--data_root_path_label',
        default='D:/project_hdr/hdr-expandnet/test_data_hdr',
        help='Path to hdr data.')
    parser.add_argument(
        '-v', '--data_root_path_pre',
        default='D:/project_hdr/myGan_same_size_wights_noexpand _minD/test_data_ldr/981',
        help='Path to predicted data.')
    parser.add_argument(
        '--num_workers', type=int, default=1,
        help='Number of data loading workers.')
    return parser.parse_args()


def transformh(hdr):
    """Rescale an image to [0, 1] with map_range and convert it to a CHW tensor."""
    hdr = map_range(hdr)
    return cv2torch(hdr)


def train(opt):
    """Print SSIM and MS-SSIM between paired predicted and reference images.

    Despite its name this only evaluates; nothing is trained. Images are
    paired by their position in each DirectoryDataset's file list, so both
    folders must hold the same number of images with matching order.

    Raises:
        ValueError: if the two folders contain a different number of images,
            which would silently misalign the pairs.
    """
    # reference (label) images
    dataset1 = DirectoryDataset(
        data_root_path=opt.data_root_path_label, preprocess=transformh)
    # predicted images to be scored against the references
    dataset2 = DirectoryDataset(
        data_root_path=opt.data_root_path_pre, preprocess=transformh)
    if len(dataset1) != len(dataset2):
        raise ValueError(
            f'Label folder has {len(dataset1)} images but prediction folder '
            f'has {len(dataset2)}; images are paired by position.')

    loader1 = DataLoader(
        dataset1, batch_size=opt.batch_size, num_workers=opt.num_workers)
    loader2 = DataLoader(
        dataset2, batch_size=opt.batch_size, num_workers=opt.num_workers)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ssim_total = 0.0
    ms_ssim_total = 0.0
    n_batches = 0

    with torch.no_grad():  # metrics only: no autograd graph needed
        for pre, real_B in tqdm(zip(loader2, loader1), total=len(loader1)):
            pre = pre.to(device)
            real_B = real_B.to(device)
            # size_average=True -> scalars averaged over the batch
            ssim_val = ssim(real_B, pre, data_range=1, size_average=True)
            ms_ssim_val = ms_ssim(real_B, pre, data_range=1, size_average=True)
            rep = (f'ssim_val: {ssim_val},'
                   f'ms_ssim_val: {ms_ssim_val},')
            tqdm.write(rep)

            ssim_total += ssim_val.item()
            ms_ssim_total += ms_ssim_val.item()
            n_batches += 1

    if n_batches:
        # mean of per-batch means (exact per-image mean when batch_size == 1)
        tqdm.write(f'mean ssim_val: {ssim_total / n_batches}, '
                   f'mean ms_ssim_val: {ms_ssim_total / n_batches}')


if __name__ == '__main__':
    opt = parse_args()
    train(opt)

下面是脚本中从 util 模块导入的几个辅助函数:

def map_range(x, low=0, high=1):
    """Linearly rescale array ``x`` so its minimum maps to ``low`` and its
    maximum to ``high``.

    The result keeps ``x``'s dtype (for integer arrays the rescaled values are
    truncated by the cast back). A constant array has no range to stretch;
    instead of handing np.interp a degenerate, non-increasing ``xp`` (whose
    result NumPy leaves unspecified), it is filled with ``low``.
    """
    lo, hi = x.min(), x.max()
    if hi == lo:
        return np.full_like(x, low)
    return np.interp(x, [lo, hi], [low, high]).astype(x.dtype)


def cv2torch(np_img):
    """Convert an H x W x 3 BGR array (OpenCV channel order) into a
    3 x H x W RGB torch tensor."""
    rgb = np_img[:, :, (2, 1, 0)]  # BGR -> RGB; fancy indexing copies
    return torch.from_numpy(rgb.swapaxes(1, 2).swapaxes(0, 1))  # HWC -> CHW

下面这个类需要安装 OpenCV(cv2)来读取图像:

class DirectoryDataset(Dataset):
    """Dataset of every file under a directory tree whose name ends with one
    of ``data_extensions`` (case-insensitive).

    Directories are visited in sorted order and filenames are sorted within
    each directory, so the order does not depend on the filesystem. This
    matters when two datasets built from parallel folders are zipped
    index-by-index.

    Args:
        data_root_path: root folder; ``~`` is expanded and the path is made
            absolute.
        data_extensions: iterable of filename suffixes to include.
        load_fn: optional callable ``path -> item``; when None the file is
            read with ``cv2.imread``.
        preprocess: optional callable applied to each loaded item.

    Raises:
        RuntimeError: if no matching file is found.
    """

    def __init__(self, data_root_path='hdr_data',
                 data_extensions=('.hdr', '.exr'), load_fn=None,
                 preprocess=None):
        super(DirectoryDataset, self).__init__()
        data_root_path = process_path(data_root_path)
        suffixes = tuple(ext.lower() for ext in data_extensions)
        self.file_list = []
        for root, _, fnames in sorted(os.walk(data_root_path)):
            for fname in sorted(fnames):
                if fname.lower().endswith(suffixes):
                    self.file_list.append(os.path.join(root, fname))
        if len(self.file_list) == 0:
            msg = 'Could not find any files with extensions:\n[{0}]\nin\n{1}'
            raise RuntimeError(
                msg.format(', '.join(data_extensions), data_root_path))
        self.load_fn = load_fn
        self.preprocess = preprocess

    def __getitem__(self, index):
        """Load the ``index``-th file and apply ``preprocess`` if set.

        Raises:
            OSError: if cv2.imread cannot decode the file (it returns None).
        """
        path = self.file_list[index]
        if self.load_fn is not None:
            dpoint = self.load_fn(path)
        else:
            # IMREAD_ANYDEPTH keeps the file's native bit depth (e.g. float
            # for HDR) instead of converting to 8 bit; IMREAD_COLOR yields BGR.
            # NOTE(review): recent OpenCV builds may need
            # OPENCV_IO_ENABLE_OPENEXR=1 to read .exr files — confirm.
            dpoint = cv2.imread(
                path, flags=cv2.IMREAD_ANYDEPTH + cv2.IMREAD_COLOR)
            if dpoint is None:
                raise OSError('cv2.imread could not read {0}'.format(path))
        if self.preprocess is not None:
            dpoint = self.preprocess(dpoint)
        return dpoint

    def __len__(self):
        return len(self.file_list)


def process_path(directory, create=False):
    """Return ``directory`` as an absolute, normalised path with ``~`` expanded.

    When ``create`` is True the directory (and any missing parents) is
    created; an existing directory is not an error. Other OS errors (e.g.
    permission denied) are no longer silently swallowed.
    """
    directory = os.path.abspath(os.path.expanduser(directory))
    if create:
        os.makedirs(directory, exist_ok=True)
    return directory

大概就是这些。MS-SSIM 的代码来源于 GitHub:https://github.com/VainF/pytorch-msssim 。当然,计算 SSIM 还有一些更简便的方法,比如使用 scikit-image:

# skimage.measure.compare_ssim was deprecated in scikit-image 0.16 and removed
# in 0.18; skimage.metrics.structural_similarity is its replacement.
try:
    from skimage.metrics import structural_similarity as compare_ssim
except ImportError:  # scikit-image < 0.16
    from skimage.measure import compare_ssim

# X, Y: the two images as numpy arrays.
# NOTE(review): for float images pass data_range explicitly, and for colour
# images pass channel_axis=-1 (multichannel=True on older versions) — confirm
# against the installed scikit-image version.
(score, diff) = compare_ssim(X, Y, full=True)
diff = (diff * 255).astype("float32")

但是 scikit-image 并没有提供计算 MS-SSIM 的函数,所以我使用了上面这个单独的实现来计算。

最新回复(0)