使用pytorch计算两组图片的ssim和ms-ssim
首先是SSIM和MS-SSIM类(ssim.py)
import torch
import torch
.nn
.functional
as F
def
_fspecial_gauss_1d(size
, sigma
):
coords
= torch
.arange(size
).to(dtype
=torch
.float
)
coords
-= size
g
= torch
.exp(-(coords
**2) / (2*sigma
**2))
g
/= g
.sum()
return g
.unsqueeze(0).unsqueeze(0)
def gaussian_filter(input, win):
    """Filter a 4-D batch with a separable 1-D kernel, one kernel per channel.

    The kernel is applied first as given and then with its last two axes
    swapped, so a kernel laid out along the last axis filters along the last
    image axis and then along the one before it. No padding is used, so each
    pass shrinks the filtered axis by ``kernel_length - 1``.

    Args:
        input: Tensor that unpacks into four dimensions; the second one is
            used as the number of convolution groups.
        win: 4-D convolution weight applied depthwise (``groups`` equals the
            channel count).

    Returns:
        The filtered tensor.
    """
    _, channels, _, _ = input.shape
    conv_options = dict(stride=1, padding=0, groups=channels)
    first_pass = F.conv2d(input, win, **conv_options)
    return F.conv2d(first_pass, win.transpose(2, 3), **conv_options)
def _ssim(X, Y, win, data_range=1023, size_average=True, full=False):
    """Compute SSIM (and optionally the contrast-structure term) for X and Y.

    Args:
        X: First image batch; must unpack into four dimensions.
        Y: Second image batch, same shape as ``X``.
        win: Filter kernel handed to ``gaussian_filter``; it is moved to the
            device and dtype of ``X`` first.
        data_range: Dynamic range of the pixel values; scales the two
            stabilising constants.
        size_average: If True, reduce each map to a single scalar. Otherwise
            average over the last three axes, leaving one value per batch
            entry.
        full: If True, also return the reduced contrast-structure value.

    Returns:
        ``ssim_val`` or, when ``full`` is True, the tuple ``(ssim_val, cs)``.
    """
    k1, k2 = 0.01, 0.03
    # Unpacking is kept for its side effect: it rejects input that is not 4-D.
    _batch, _channel, _height, _width = X.shape
    variance_scale = 1.0

    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2

    kernel = win.to(X.device, dtype=X.dtype)

    def _blur(tensor):
        return gaussian_filter(tensor, kernel)

    # Local means and their products.
    mean_x = _blur(X)
    mean_y = _blur(Y)
    mean_x_sq = mean_x.pow(2)
    mean_y_sq = mean_y.pow(2)
    mean_xy = mean_x * mean_y

    # Local (co)variances: E[ab] - E[a]E[b].
    var_x = variance_scale * (_blur(X * X) - mean_x_sq)
    var_y = variance_scale * (_blur(Y * Y) - mean_y_sq)
    cov_xy = variance_scale * (_blur(X * Y) - mean_xy)

    cs_map = (2 * cov_xy + c2) / (var_x + var_y + c2)
    ssim_map = ((2 * mean_xy + c1) / (mean_x_sq + mean_y_sq + c1)) * cs_map

    def _reduce(score_map):
        if size_average:
            return score_map.mean()
        # Average over the last three axes, one after another.
        return score_map.mean(-1).mean(-1).mean(-1)

    ssim_val = _reduce(ssim_map)
    cs = _reduce(cs_map)

    return (ssim_val, cs) if full else ssim_val
def ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1, size_average=True, full=False):
    """Validate the inputs and compute the SSIM between two image batches.

    NOTE(review): the defaults here (``win_sigma=10``, ``data_range=1``)
    differ from the ``SSIM`` class defaults in this file (1.5 and 255) -
    confirm which values are intended.

    Args:
        X: First image batch, a 4-D tensor.
        Y: Second image batch; same type and shape as ``X``.
        win_size: Length of the Gaussian window built when ``win`` is None;
            must be odd.
        win_sigma: Sigma of the Gaussian window built when ``win`` is None.
        win: Optional precomputed window; when given it is used as is.
        data_range: Dynamic range of the pixel values.
        size_average: If True, average the per-image results to a scalar.
        full: If True, also return the contrast-structure value.

    Returns:
        ``ssim_val`` or, when ``full`` is True, ``(ssim_val, cs)``.

    Raises:
        ValueError: If the inputs are not 4-D, differ in type or shape, or
            ``win_size`` is even.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')
    if X.type() != Y.type():
        raise ValueError('Input images must have the same dtype.')
    if X.shape != Y.shape:
        raise ValueError('Input images must have the same dimensions.')
    if win_size % 2 != 1:
        raise ValueError('Window size must be odd.')

    if win is None:
        # One copy of the 1-D window per input channel.
        base_window = _fspecial_gauss_1d(win_size, win_sigma)
        win = base_window.repeat(X.shape[1], 1, 1, 1)
    else:
        win_size = win.shape[-1]  # mirrors the supplied window; not used below

    # Always request per-image values here; averaging is applied afterwards.
    ssim_val, cs = _ssim(
        X, Y, win=win, data_range=data_range, size_average=False, full=True)

    if size_average:
        ssim_val, cs = ssim_val.mean(), cs.mean()

    return (ssim_val, cs) if full else ssim_val
def ms_ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1, size_average=True, full=False, weights=None):
    """Compute the multi-scale SSIM between two image batches.

    At every scale the contrast-structure value is recorded; between scales
    both images are halved with 2x2 average pooling. The result is

        prod_j cs_j ** weights[j]  (all scales but the last)
            * ssim_last ** weights[-1]

    NOTE(review): the defaults here (``win_sigma=10``, ``data_range=1``)
    differ from the ``MS_SSIM`` class defaults in this file (1.5 and 255) -
    confirm which values are intended.
    NOTE(review): a negative ``cs`` raised to a fractional weight yields NaN;
    no clamping is applied here.

    Args:
        X: First image batch, a 4-D tensor.
        Y: Second image batch; same type and shape as ``X``.
        win_size: Length of the Gaussian window built when ``win`` is None;
            must be odd.
        win_sigma: Sigma of the Gaussian window built when ``win`` is None.
        win: Optional precomputed window; when given it is used as is.
        data_range: Dynamic range of the pixel values.
        size_average: If True, average the per-image results to a scalar.
        full: Accepted for signature parity with ``ssim``; not used.
        weights: Optional 1-D tensor with one exponent per scale. Defaults to
            five scales.

    Returns:
        The MS-SSIM value: a scalar when ``size_average`` is True, otherwise
        one value per batch entry.

    Raises:
        ValueError: If the inputs are not 4-D, differ in type or shape,
            ``win_size`` is even, or ``weights`` is empty.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')

    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if weights is None:
        weights = torch.FloatTensor(
            [0.0448, 0.2856, 0.3001, 0.2363, 0.1333])
    # Caller-supplied weights are aligned with X too, so the exponentiation
    # below cannot fail on a device or dtype mismatch.
    weights = weights.to(X.device, dtype=X.dtype)

    levels = weights.shape[0]
    if levels == 0:
        raise ValueError('weights must contain at least one level.')

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)

    mcs = []
    for level in range(levels):
        ssim_val, cs = _ssim(X, Y,
                             win=win,
                             data_range=data_range,
                             size_average=False,
                             full=True)
        mcs.append(cs)

        # Downsample for the next scale; nothing consumes a pooled copy of
        # the last scale, so skip it there.
        if level < levels - 1:
            padding = (X.shape[2] % 2, X.shape[3] % 2)
            X = F.avg_pool2d(X, kernel_size=2, padding=padding)
            Y = F.avg_pool2d(Y, kernel_size=2, padding=padding)

    mcs = torch.stack(mcs, dim=0)  # (level, batch)

    # Contrast-structure terms of all scales but the last: (batch,)
    cs_product = torch.prod(mcs[:-1] ** weights[:-1].unsqueeze(1), dim=0)
    # The last-scale SSIM term is applied exactly once. Multiplying it inside
    # the product would raise it to weights[-1] * (levels - 1) instead.
    msssim_val = cs_product * (ssim_val ** weights[-1])  # (batch,)

    if size_average:
        msssim_val = msssim_val.mean()
    return msssim_val
# Classes to re-use window
class SSIM(torch.nn.Module):
    """Module wrapper around ``ssim`` that builds its Gaussian window once.

    Args:
        win_size: Length of the Gaussian window.
        win_sigma: Sigma of the Gaussian window.
        data_range: Dynamic range of the pixel values passed to ``ssim``.
        size_average: Whether ``ssim`` averages the result to a scalar.
        channel: Number of copies of the window to create (one per channel).
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=255, size_average=True, channel=3):
        super().__init__()
        base_window = _fspecial_gauss_1d(win_size, win_sigma)
        # Stored as a plain attribute (not a registered buffer or parameter).
        self.win = base_window.repeat(channel, 1, 1, 1)
        self.data_range = data_range
        self.size_average = size_average

    def forward(self, X, Y):
        """Return ``ssim(X, Y)`` using the stored window and settings."""
        settings = {
            'win': self.win,
            'data_range': self.data_range,
            'size_average': self.size_average,
        }
        return ssim(X, Y, **settings)
class MS_SSIM(torch.nn.Module):
    """Module wrapper around ``ms_ssim`` that builds its window once.

    Args:
        win_size: Length of the Gaussian window.
        win_sigma: Sigma of the Gaussian window.
        data_range: Dynamic range of the pixel values passed to ``ms_ssim``.
        size_average: Whether ``ms_ssim`` averages the result to a scalar.
        channel: Number of copies of the window to create (one per channel).
        weights: Per-scale weights forwarded to ``ms_ssim``; None lets
            ``ms_ssim`` pick its own default.
    """

    def __init__(self, win_size=11, win_sigma=1.5, data_range=255, size_average=True, channel=3, weights=None):
        super().__init__()
        base_window = _fspecial_gauss_1d(win_size, win_sigma)
        # Stored as a plain attribute (not a registered buffer or parameter).
        self.win = base_window.repeat(channel, 1, 1, 1)
        self.weights = weights
        self.data_range = data_range
        self.size_average = size_average

    def forward(self, X, Y):
        """Return ``ms_ssim(X, Y)`` using the stored window and settings."""
        settings = {
            'win': self.win,
            'size_average': self.size_average,
            'data_range': self.data_range,
            'weights': self.weights,
        }
        return ms_ssim(X, Y, **settings)
上面的工具类我在pytorch中当做损失函数使用
使用
这里还用到几个方法,我在下面给出
import argparse
from tqdm
import tqdm
import torch
from torch
.utils
.data
import DataLoader
from util
import (map_range
, cv2torch
, random_tone_map
,
DirectoryDataset
, str2bool
)
# 这里我把上面的ssim.py放到了一个文件夹中,所以需要这样导入
from pytorch_msssim
import ssim
, ms_ssim
, SSIM, MS_SSIM
def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. When None, argparse reads
            the process command line, which is what existing callers get.

    Returns:
        An ``argparse.Namespace`` with ``batch_size``,
        ``data_root_path_label``, ``data_root_path_pre`` and ``num_workers``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batch_size', type=int, default=1, help='Batch size.')
    parser.add_argument(
        '-d', '--data_root_path_label',
        default='D:/project_hdr/hdr-expandnet/test_data_hdr',
        help='Path to hdr data.')
    # This option feeds the images that are compared against the hdr data,
    # so its help text must not repeat the one for the label path.
    parser.add_argument(
        '-v', '--data_root_path_pre',
        default='D:/project_hdr/myGan_same_size_wights_noexpand _minD/test_data_ldr/981',
        help='Path to the predicted images compared against the hdr data.')
    parser.add_argument(
        '--num_workers',
        type=int,
        default=1,
        help='Number of data loading workers.')
    return parser.parse_args(argv)
def transformh(hdr):
    """Preprocess one loaded image: ``map_range`` it, then ``cv2torch`` it.

    Both helpers are imported from ``util``; ``map_range`` is called with
    its default arguments.
    """
    return cv2torch(map_range(hdr))
def train(opt):
    """Print SSIM and MS-SSIM for image batches paired from two directories.

    Despite its name, this function trains nothing: it only evaluates and
    reports the two metrics.

    Args:
        opt: Options object providing ``data_root_path_label``,
            ``data_root_path_pre``, ``batch_size`` and ``num_workers``.
    """
    def _make_loader(root):
        dataset = DirectoryDataset(data_root_path=root, preprocess=transformh)
        return DataLoader(
            dataset,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
        )

    # Reference images (label path) are loaded first, then the images to
    # be scored.
    reference_loader = _make_loader(opt.data_root_path_label)
    predicted_loader = _make_loader(opt.data_root_path_pre)

    # Batches are paired purely by position in the two loaders.
    for predicted, reference in zip(predicted_loader, reference_loader):
        if torch.cuda.is_available():
            predicted = predicted.cuda()
            reference = reference.cuda()

        # size_average=True is requested, i.e. the averaged value.
        metric_options = dict(data_range=1, size_average=True)
        ssim_val = ssim(reference, predicted, **metric_options)
        ms_ssim_val = ms_ssim(reference, predicted, **metric_options)

        tqdm.write(f'ssim_val: {ssim_val},ms_ssim_val: {ms_ssim_val},')
# Script entry point: parse the command-line options and run the comparison.
if __name__ == '__main__':
    opt = parse_args()
    train(opt)
下面是使用到的几个方法
def map_range(x, low=0, high=1):
    """Linearly rescale ``x`` so its minimum maps to ``low`` and maximum to ``high``.

    Args:
        x: Array providing ``min()``, ``max()`` and ``dtype``.
        low: Output value for the smallest element of ``x``.
        high: Output value for the largest element of ``x``.

    Returns:
        The rescaled array, cast back to the dtype of ``x``.
    """
    source_span = [x.min(), x.max()]
    target_span = [low, high]
    rescaled = np.interp(x, source_span, target_span)
    return rescaled.astype(x.dtype)
# var foo = 'bar';  (stray non-Python line, commented out)
def cv2torch(np_img):
    """Convert a channels-last image array into a channels-first torch tensor.

    The third axis is reordered with indices ``(2, 1, 0)`` (presumably
    OpenCV's BGR to RGB - confirm against the loader), then that axis is
    moved to the front.

    Args:
        np_img: NumPy array whose third axis holds at least three channels.

    Returns:
        A tensor created with ``torch.from_numpy`` from the reordered array.
    """
    channel_order = (2, 1, 0)
    reordered = np_img[:, :, channel_order]
    # (H, W, C) -> (C, W, H) -> (C, H, W)
    channels_first = reordered.swapaxes(0, 2).swapaxes(1, 2)
    return torch.from_numpy(channels_first)
下面这个类需要opencv
class DirectoryDataset(Dataset):
    """Dataset of image files found recursively below a directory.

    Files are matched case-insensitively by extension and loaded with
    ``cv2.imread``.

    Args:
        data_root_path: Directory that is walked recursively.
        data_extensions: Iterable of accepted file name suffixes.
        load_fn: Accepted for interface compatibility; not used.
        preprocess: Optional callable applied to every loaded image.

    Raises:
        RuntimeError: If no matching file is found, or (on item access) if a
            file cannot be decoded.
    """

    def __init__(self,
                 data_root_path='hdr_data',
                 data_extensions=('.hdr', '.exr'),
                 load_fn=None,
                 preprocess=None):
        super(DirectoryDataset, self).__init__()

        data_root_path = process_path(data_root_path)
        # str.endswith accepts a tuple, so one call checks all suffixes.
        extensions = tuple(data_extensions)
        self.file_list = []
        for root, _, fnames in sorted(os.walk(data_root_path)):
            # os.walk lists files in arbitrary order. Sorting keeps indices
            # stable, which matters when two datasets are paired by position.
            for fname in sorted(fnames):
                if fname.lower().endswith(extensions):
                    self.file_list.append(os.path.join(root, fname))
        if not self.file_list:
            msg = 'Could not find any files with extensions:\n[{0}]\nin\n{1}'
            raise RuntimeError(
                msg.format(', '.join(extensions), data_root_path))

        self.preprocess = preprocess

    def __getitem__(self, index):
        """Load the image at ``index`` and apply ``preprocess`` if set."""
        path = self.file_list[index]
        dpoint = cv2.imread(
            path, flags=cv2.IMREAD_ANYDEPTH + cv2.IMREAD_COLOR)
        if dpoint is None:
            # Fail here with the file name instead of handing None to
            # preprocess or to the caller.
            raise RuntimeError('Could not read image file: {0}'.format(path))
        if self.preprocess is not None:
            dpoint = self.preprocess(dpoint)
        return dpoint

    def __len__(self):
        """Return the number of files found."""
        return len(self.file_list)
def process_path(directory, create=False):
    """Return ``directory`` as an absolute path, optionally creating it.

    Args:
        directory: Path to process; a leading ``~`` is expanded.
        create: If True, try to create the directory and any missing parents.

    Returns:
        The expanded, normalised, absolute path.
    """
    directory = os.path.expanduser(directory)
    directory = os.path.normpath(directory)
    directory = os.path.abspath(directory)
    if create:
        try:
            # exist_ok covers the expected case of an existing directory.
            os.makedirs(directory, exist_ok=True)
        except OSError:
            # Creation stays best effort: the path is returned either way.
            # Only OS-level failures are ignored, so KeyboardInterrupt and
            # programming errors are no longer swallowed.
            pass
    return directory
大概就是这些,msssim代码来源于github : https://github.com/VainF/pytorch-msssim 当然,计算ssim还有一些其他更加简便的方法,比如
from skimage.measure import compare_ssim

# Single-scale SSIM via scikit-image. X and Y are not defined in this snippet
# and must be supplied by the reader. With full=True the call returns a pair,
# unpacked here as (score, diff).
# NOTE(review): newer scikit-image releases appear to provide this function as
# skimage.metrics.structural_similarity - confirm against the installed version.
(score, diff) = compare_ssim(X, Y, full=True)
# Scale diff by 255 and convert it to float32.
diff = (diff * 255).astype("float32")
但是python计算msssim的就没有了,所以我用的这个单独的方法去计算的