Image segmentation labels: one-hot encoding with three operations, np.equal, np.all and np.stack

mac 2026-02-06

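In image segmentation the label is also an image, so before training, the label image has to be converted to one-hot form. This post implements the one-hot encoding using only three operations: np.equal, np.all and np.stack. Here `label` is the ground-truth (GT) image and `label_values` is the list of class RGB values.

1. np.equal compares the RGB value of every pixel in the label image with the RGB value of one class, producing one boolean per channel (an (H, W, 3) boolean array).
2. np.all reduces those three RGB booleans to a single boolean per pixel, which gives the label mask of that class (an (H, W) boolean array). A for loop generates the masks for all classes.
3. np.stack stacks the label masks of all classes, so the final depth equals num_classes (an (H, W, num_classes) array).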
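To make the shapes concrete, here is a minimal toy run of the three steps. It assumes a hypothetical 2×2 label image built from two colours in the CSV below (Animal and Building); the image itself is made up for illustration.

```python
import numpy as np

# Hypothetical 2x2 RGB label image (H=2, W=2, 3 channels)
# using the Animal (64,128,64) and Building (128,0,0) colours.
label = np.array([
    [[64, 128, 64], [128, 0, 0]],
    [[128, 0, 0],   [64, 128, 64]],
])
label_values = [[64, 128, 64], [128, 0, 0]]  # Animal, Building

masks = []
for colour in label_values:
    # Step 1: per-channel comparison; (2, 2, 3) broadcast against (3,) -> (2, 2, 3) bool
    equality = np.equal(label, colour)
    # Step 2: a pixel matches only if R, G and B all match -> (2, 2) bool
    class_map = np.all(equality, axis=-1)
    masks.append(class_map)

# Step 3: stack the per-class masks along a new last axis -> (2, 2, num_classes)
one_hot = np.stack(masks, axis=-1)

print(one_hot.shape)                         # (2, 2, 2)
print(one_hot[..., 0].astype(int).tolist())  # Animal mask: [[1, 0], [0, 1]]
```

Because every pixel here matches exactly one class colour, each pixel's one-hot vector sums to 1.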

The CSV file looks like this:

```csv
name,r,g,b
Animal,64,128,64
Archway,192,0,128
Bicyclist,0,128, 192
Bridge,0, 128, 64
Building,128, 0, 0
Car,64, 0, 128
CartLuggagePram,64, 0, 192
Child,192, 128, 64
Column_Pole,192, 192, 128
Fence,64, 64, 128
LaneMkgsDriv,128, 0, 192
LaneMkgsNonDriv,192, 0, 64
```
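Some rows above have a space after the comma (e.g. `0, 128, 64`). A quick sketch, reading a few of those rows from an in-memory string rather than the real file, suggests this is harmless for the parsing used below, because Python's `int()` ignores surrounding whitespace:

```python
import csv
import io

# A few rows copied from the CSV above, including the stray spaces after commas.
csv_text = """name,r,g,b
Animal,64,128,64
Bicyclist,0,128, 192
Bridge,0, 128, 64
"""

reader = csv.reader(io.StringIO(csv_text), delimiter=',')
next(reader)  # skip the header row
class_names, label_values = [], []
for row in reader:
    class_names.append(row[0])
    # row[3] may be ' 192'; int() strips the leading space
    label_values.append([int(row[1]), int(row[2]), int(row[3])])

print(class_names)   # ['Animal', 'Bicyclist', 'Bridge']
print(label_values)  # [[64, 128, 64], [0, 128, 192], [0, 128, 64]]
```

If the class names themselves could carry leading spaces, `csv.reader(..., skipinitialspace=True)` strips whitespace right after each delimiter for every field.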

```python
import csv
import os

import numpy as np


def get_label_info(csv_path):
    """
    Retrieve the class names and label values for the selected dataset.
    Must be in CSV format!

    # Arguments
        csv_path: The file path of the class dictionary

    # Returns
        Two lists: one for the class names and the other for the label values
    """
    _, file_extension = os.path.splitext(csv_path)
    if file_extension != ".csv":
        raise ValueError("File is not a CSV!")  # raise, not return

    class_names = []
    label_values = []
    with open(csv_path, 'r') as csvfile:
        file_reader = csv.reader(csvfile, delimiter=',')
        next(file_reader)  # skip the header row: name,r,g,b
        for row in file_reader:
            class_names.append(row[0])
            label_values.append([int(row[1]), int(row[2]), int(row[3])])
    return class_names, label_values


def one_hot_it(label, label_values):
    """
    Convert a segmentation image label array to one-hot format
    by replacing each pixel value with a vector of length num_classes

    # Arguments
        label: The segmentation label image, an (H, W, 3) RGB array
        label_values: List of [r, g, b] values, one per class

    # Returns
        A 3D array with the same width and height as the input,
        but with a depth size of num_classes
    """
    semantic_map = []
    for colour in label_values:
        # Compare every pixel's RGB with this class colour -> (H, W, 3) bool
        equality = np.equal(label, colour)
        # A pixel belongs to the class only if all three channels match -> (H, W)
        class_map = np.all(equality, axis=-1)
        semantic_map.append(class_map)
    # Stack the per-class masks along the last axis -> (H, W, num_classes)
    semantic_map = np.stack(semantic_map, axis=-1)
    return semantic_map
```
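One consequence of this approach worth checking: a pixel whose colour is not in `label_values` matches no class, so its one-hot vector is all zeros. The sketch below uses a hypothetical 1×3 image (two colours from the CSV plus an unlisted black pixel) and the same equal → all → stack sequence written inline. It shows how `any` detects such pixels and that `np.argmax` would silently map them to class 0:

```python
import numpy as np

label_values = [[64, 128, 64], [192, 0, 128]]  # Animal, Archway (from the CSV above)

# Hypothetical 1x3 label image; the last pixel (0, 0, 0) is not a listed class colour
label = np.array([[[64, 128, 64], [192, 0, 128], [0, 0, 0]]])

# Same three operations as one_hot_it: equal -> all -> stack
one_hot = np.stack([np.all(np.equal(label, c), axis=-1) for c in label_values], axis=-1)

covered = one_hot.any(axis=-1)              # False where no class colour matched
class_index = np.argmax(one_hot, axis=-1)   # index of the matching class per pixel

print(covered.tolist())      # [[True, True, False]]
print(class_index.tolist())  # [[0, 1, 0]]  (the unmatched pixel falls back to 0)
```

So before training it can be worth asserting `covered.all()`, or making sure the CSV lists every colour that actually occurs in the label images.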