Hyperspectral-Classification Pytorch 数据集的读取、划分、加载
文章目录
Hyperspectral-Classification Pytorch 数据集的读取、划分、加载数据集读取:流程:代码:main.py:datasets.py:utils.py:
ground truth划分:流程:代码:main.py:utils.py:
生成样本的dataset和dataloader:流程代码:打印信息:
生成DataLoader流程代码:打印信息打印信息
这里只关心
sample的
gt的读取、划分、加载。
不关心其他参数如ignored_labels等。
数据集读取:
流程:
位置函数作用
主程序调用————datasets.pyget_dataset()得到img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, paletteutilis.pyopen_file()读取数据集,包括PaviaU.mat和PaviaU_gt.mat
代码:
main.py:
img
, gt
, LABEL_VALUES
, IGNORED_LABELS
, RGB_BANDS
, palette
= get_dataset
(DATASET
, FOLDER
)
datasets.py:
elif dataset_name
== 'PaviaU':
img
= open_file
(folder
+ 'PaviaU.mat')['paviaU']
rgb_bands
= (55, 41, 12)
gt
= open_file
(folder
+ 'PaviaU_gt.mat')['paviaU_gt']
label_values
= ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees',
'Painted metal sheets', 'Bare Soil', 'Bitumen',
'Self-Blocking Bricks', 'Shadows']
ignored_labels
= [0]
utils.py:
def open_file(dataset
):
_
, ext
= os
.path
.splitext
(dataset
)
ext
= ext
.lower
()
if ext
== '.mat':
return io
.loadmat
(dataset
)
elif ext
== '.tif' or ext
== '.tiff':
return misc
.imread
(dataset
)
elif ext
== '.hdr':
img
= spectral
.open_image
(dataset
)
return img
.load
()
else:
raise ValueError
("Unknown file format: {}".format(ext
))
ground truth划分:
流程:
位置函数作用
主程序调用————utils.pysample_gt()将非ignored_labels通过取索引的方式,随机划分到train_gt和test_gt_spilt.pysklearn.model_selection.train_test_split()根据所给比例,随机划分输入数据
额外说明:
通过gt划分得到的train_gt和test_gt,仍在原图像维度中,只是将取得元素的位置设定为对应的label,而其他位置置零(默认0为ignored_labels)。
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bMhMLvGE-1570353298738)(C:\Users\73416\Desktop\MarkDown\图片文件夹\Hyperspectral-Classification Pytorch 数据集的读取、划分、加载\Train ground truth.jpg)]
代码:
main.py:
train_gt
, test_gt
= sample_gt
(gt
, SAMPLE_PERCENTAGE
, mode
=SAMPLING_MODE
)
……
……
test_gt
, val_gt
= sample_gt
(test_gt
, 0.95, mode
='random')
utils.py:
def sample_gt(gt
, train_size
, mode
='random'):
"""Extract a fixed percentage of samples from an array of labels.
Args:
gt: a 2D array of int labels
percentage: [0, 1] float
Returns:
train_gt, test_gt: 2D arrays of int labels
"""
indices
= np
.nonzero
(gt
)
X
= list(zip(*indices
))
y
= gt
[indices
].ravel
()
train_gt
= np
.zeros_like
(gt
)
test_gt
= np
.zeros_like
(gt
)
if train_size
> 1:
train_size
= int(train_size
)
if mode
== 'random':
train_indices
, test_indices
= sklearn
.model_selection
.train_test_split
(X
, train_size
=train_size
, stratify
=y
)
train_indices
= [list(t
) for t
in zip(*train_indices
)]
test_indices
= [list(t
) for t
in zip(*test_indices
)]
train_gt
[train_indices
] = gt
[train_indices
]
test_gt
[test_indices
] = gt
[test_indices
]
elif mode
== 'fixed':
……
……
……
else:
raise ValueError
("{} sampling is not implemented yet.".format(mode
))
return train_gt
, test_gt
生成样本的dataset和dataloader:
###HyperX的对象train_dataset
流程
位置函数作用
主程序调用————datasets.pyclass HyperX(torch.utils.data.Dataset):Generic class for a hyperspectral scene
代码:
train_dataset
= HyperX
(img
, train_gt
, **hyperparams
)
……
val_dataset
= HyperX
(img
, val_gt
, **hyperparams
)
强调一个疑问
这里生成对象train_dataset和val_dataset的时候,用的dataset是整个的img,但是用到的ground_truth是部分gt(train_gt和val_gt)。
为了解决这个疑问,我做了一个小测试:
for batch_idx
, (data
, target
) in tqdm
(enumerate(data_loader
), total
=len(data_loader
)):
data
, target
= data
.to
(device
), target
.to
(device
)
if batch_idx
% 100 ==0:
print('initial data shape:',data
.shape
)
if 0 in target
:
os
.system
('pause')
在训练网络的时候,会从data_loader中取出training sample,我加入了一个判断代码,如果取出的training sample中有**‘0’**标签,则打断程序的运行。
整体的检验只用一个epoch就可以。
实际运行的结果发现并不会传入标签为0的元素。标签为0的元素表示本身为标签为0或者没有被分到相应的数据集(这里是训练集)。
结论就是,生成高光谱训练样本的对象时,可以传入整张高光谱图像,只在gt上区分数据集类型。
打印信息:
打印对象train_dataset的属性:
运行示例命令:python C:\Users\73416\PycharmProjects\HSIproject\main.py --model nn --dataset PaviaU --training_sample 0.1 --cuda 0。
center_pixel
True
data
[[[0.080875 0.062375 0.058 ... 0.402625 0.40475 0.40625 ]
[0.0755 0.06825 0.065875 ... 0.30525 0.308 0.316 ]
[0.077625 0.09325 0.0695 ... 0.2885 0.293125 0.295125]
...
...
[0.074125 0.048375 0.0535 ... 0.29775 0.300875 0.302875]
[0.074125 0.093875 0.081875 ... 0.289 0.2885 0.286125]
[0.111125 0.09 0.056125 ... 0.302 0.305875 0.310625]]]
……
flip_augmentation
False
ignored_labels
{0}
indices
[[175 4]
[536 248]
[362 26]
...
[369 37]
[295 212]
[511 223]]
label(与像素点对应的label值)
[[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
...
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]
[0 0 0 ... 0 0 0]]
labels(除去ignored_labels的label值)
[4, 1, 1, 4, 4, 1, 4, 4, 4, 1, 1, 1, 4, 4, 4, 4, 1, 1, 4, 4, 4, 1, 1, 4, 1, 4, 1, 1, 4, 4, 4, 1, 4, 1, 1, 4, 4, 4, 1, 1, 4, 4, 4, 1, 4, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 4, 4, 1, 1, 1, 4, 1,
……
……
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]
mixture_augmentation
False
……
name
PaviaU
patch_size
1
radiation_augmentation
False
……
生成DataLoader
流程
位置函数作用
主程序调用————dataloader.py————
代码:
train_loader
= data
.DataLoader
(train_dataset
,
batch_size
=hyperparams
['batch_size'],
shuffle
=True)
……
val_loader
= data
.DataLoader
(val_dataset
,
batch_size
=hyperparams
['batch_size'])
打印信息
……
batch_sampler
<torch.utils.data.sampler.BatchSampler object at 0x000002474D3FF780>
……
batch_size
100
……
dataset
<datasets.HyperX object at 0x000002474D3FF748>
……
<torch.utils.data.sampler.RandomSampler object at 0x000002474D3FF668>
……
打印信息
……
batch_sampler
<torch.utils.data.sampler.BatchSampler object at 0x000002474D3FF780>
……
batch_size
100
……
dataset
<datasets.HyperX object at 0x000002474D3FF748>
……
<torch.utils.data.sampler.RandomSampler object at 0x000002474D3FF668>
……