[高光谱] Hyperspectral-Classification Pytorch 数据集的读取、划分、加载

mac2022-07-05  17

Hyperspectral-Classification Pytorch 数据集的读取、划分、加载

文章目录

Hyperspectral-Classification Pytorch 数据集的读取、划分、加载数据集读取:流程:代码:main.py:datasets.py:utils.py: ground truth划分:流程:代码:main.py:utils.py: 生成样本的dataset和dataloader:流程代码:打印信息: 生成DataLoader流程代码:打印信息打印信息 这里只关心 sample的 gt的读取、划分、加载。

不关心其他参数如ignored_labels等。

数据集读取:

流程:

位置函数作用主程序调用————datasets.pyget_dataset()得到img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, paletteutilis.pyopen_file()读取数据集,包括PaviaU.mat和PaviaU_gt.mat

代码:

main.py:
img, gt, LABEL_VALUES, IGNORED_LABELS, RGB_BANDS, palette = get_dataset(DATASET, FOLDER)
datasets.py:
elif dataset_name == 'PaviaU': # Load the image img = open_file(folder + 'PaviaU.mat')['paviaU'] rgb_bands = (55, 41, 12) gt = open_file(folder + 'PaviaU_gt.mat')['paviaU_gt'] label_values = ['Undefined', 'Asphalt', 'Meadows', 'Gravel', 'Trees', 'Painted metal sheets', 'Bare Soil', 'Bitumen', 'Self-Blocking Bricks', 'Shadows'] ignored_labels = [0]
utils.py:
def open_file(dataset): _, ext = os.path.splitext(dataset) ext = ext.lower() if ext == '.mat': # Load Matlab array return io.loadmat(dataset) elif ext == '.tif' or ext == '.tiff': # Load TIFF file return misc.imread(dataset) elif ext == '.hdr': img = spectral.open_image(dataset) return img.load() else: raise ValueError("Unknown file format: {}".format(ext))

ground truth划分:

流程:

位置函数作用主程序调用————utils.pysample_gt()将非ignored_labels通过取索引的方式,随机划分到train_gt和test_gt_spilt.pysklearn.model_selection.train_test_split()根据所给比例,随机划分输入数据

额外说明:

通过gt划分得到的train_gt和test_gt,仍在原图像维度中,只是将取得元素的位置设定为对应的label,而其他位置置零(默认0为ignored_labels)。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bMhMLvGE-1570353298738)(C:\Users\73416\Desktop\MarkDown\图片文件夹\Hyperspectral-Classification Pytorch 数据集的读取、划分、加载\Train ground truth.jpg)]

代码:

main.py:
train_gt, test_gt = sample_gt(gt, SAMPLE_PERCENTAGE, mode=SAMPLING_MODE) …… …… test_gt, val_gt = sample_gt(test_gt, 0.95, mode='random')
utils.py:
def sample_gt(gt, train_size, mode='random'): """Extract a fixed percentage of samples from an array of labels. Args: gt: a 2D array of int labels percentage: [0, 1] float Returns: train_gt, test_gt: 2D arrays of int labels """ indices = np.nonzero(gt) X = list(zip(*indices)) # x,y features (r,c)形式的位置的索引 y = gt[indices].ravel() # classes train_gt = np.zeros_like(gt) test_gt = np.zeros_like(gt) if train_size > 1: train_size = int(train_size) if mode == 'random': train_indices, test_indices = sklearn.model_selection.train_test_split(X, train_size=train_size, stratify=y) train_indices = [list(t) for t in zip(*train_indices)] test_indices = [list(t) for t in zip(*test_indices)] train_gt[train_indices] = gt[train_indices] test_gt[test_indices] = gt[test_indices] elif mode == 'fixed': …… …… …… else: raise ValueError("{} sampling is not implemented yet.".format(mode)) return train_gt, test_gt

生成样本的dataset和dataloader:

###HyperX的对象train_dataset

流程

位置函数作用主程序调用————datasets.pyclass HyperX(torch.utils.data.Dataset):Generic class for a hyperspectral scene

代码:

train_dataset = HyperX(img, train_gt, **hyperparams) …… val_dataset = HyperX(img, val_gt, **hyperparams)

强调一个疑问

这里生成对象train_dataset和val_dataset的时候,用的dataset是整个的img,但是用到的ground_truth是部分gt(train_gt和val_gt)。

为了解决这个疑问,我做了一个小测试:

for batch_idx, (data, target) in tqdm(enumerate(data_loader), total=len(data_loader)): # Load the data into the GPU if required data, target = data.to(device), target.to(device) # ------------自加打印原始的输入维度------------------ # print(type(target)) # print(target.shape) if batch_idx % 100 ==0: print('initial data shape:',data.shape) if 0 in target: os.system('pause') # ------------自加打印原始的输入维度------------------

在训练网络的时候,会从data_loader中取出training sample,我加入了一个判断代码,如果取出的training sample中有**‘0’**标签,则打断程序的运行。

整体的检验只用一个epoch就可以。

实际运行的结果发现并不会传入标签为0的元素。标签为0的元素表示本身为标签为0或者没有被分到相应的数据集(这里是训练集)。

结论就是,生成高光谱训练样本的对象时,可以传入整张高光谱图像,只在gt上区分数据集类型。

打印信息:

打印对象train_dataset的属性:

运行示例命令:python C:\Users\73416\PycharmProjects\HSIproject\main.py --model nn --dataset PaviaU --training_sample 0.1 --cuda 0。

center_pixel True data [[[0.080875 0.062375 0.058 ... 0.402625 0.40475 0.40625 ] [0.0755 0.06825 0.065875 ... 0.30525 0.308 0.316 ] [0.077625 0.09325 0.0695 ... 0.2885 0.293125 0.295125] ... ... [0.074125 0.048375 0.0535 ... 0.29775 0.300875 0.302875] [0.074125 0.093875 0.081875 ... 0.289 0.2885 0.286125] [0.111125 0.09 0.056125 ... 0.302 0.305875 0.310625]]] …… flip_augmentation False ignored_labels {0} indices [[175 4] [536 248] [362 26] ... [369 37] [295 212] [511 223]] label(与像素点对应的label值) [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]] labels(除去ignored_labels的label值) [4, 1, 1, 4, 4, 1, 4, 4, 4, 1, 1, 1, 4, 4, 4, 4, 1, 1, 4, 4, 4, 1, 1, 4, 1, 4, 1, 1, 4, 4, 4, 1, 4, 1, 1, 4, 4, 4, 1, 1, 4, 4, 4, 1, 4, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1, 1, 4, 4, 1, 1, 1, 4, 1, …… …… 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2] mixture_augmentation False …… name PaviaU patch_size 1 radiation_augmentation False ……

生成DataLoader

流程

位置函数作用主程序调用————dataloader.py————

代码:

train_loader = data.DataLoader(train_dataset, batch_size=hyperparams['batch_size'], #pin_memory=hyperparams['device'], shuffle=True) …… val_loader = data.DataLoader(val_dataset, #pin_memory=hyperparams['device'], batch_size=hyperparams['batch_size'])

打印信息

…… batch_sampler <torch.utils.data.sampler.BatchSampler object at 0x000002474D3FF780> …… batch_size 100 …… dataset <datasets.HyperX object at 0x000002474D3FF748> …… <torch.utils.data.sampler.RandomSampler object at 0x000002474D3FF668> ……

打印信息

…… batch_sampler <torch.utils.data.sampler.BatchSampler object at 0x000002474D3FF780> …… batch_size 100 …… dataset <datasets.HyperX object at 0x000002474D3FF748> …… <torch.utils.data.sampler.RandomSampler object at 0x000002474D3FF668> ……
最新回复(0)