TensorFlow Faster R-CNN Source Code Walkthrough (TFFRCNN) (12): gt


This post is part of a series of code-reading notes on the CharlesShang/TFFRCNN source on GitHub.

--------------- Personal study notes ---------------

---------------- Author: Wu Jiang ----------------


 

A "Find in Path" search shows that none of the functions in this file are ever executed: the call that constructs GtDataLayer in train.py is commented out, since the layer depends on the Caffe modules. The functions defined here closely mirror those in roi_data_layer/layer.py.

""" The data layer used during training to train a Fast R-CNN network. GtDataLayer implements a Caffe Python layer. """

The file defines a GtDataLayer class (no callers found in the codebase) containing 8 methods:

class GtDataLayer(caffe.Layer):
    """Fast R-CNN data layer used for training."""

1. _shuffle_roidb_inds(self)

Randomly shuffles the roidb (the rois of all training images), producing the index array self._perm and the start cursor self._cur. Similar to its counterpart in roi_data_layer/layer.py; called by _get_next_minibatch_inds(...) and set_roidb(...).

def _shuffle_roidb_inds(self):
    """Randomly permute the training roidb."""
    # randomly shuffle the image indices
    self._perm = np.random.permutation(np.arange(len(self._roidb)))
    # reset the cursor to the start
    self._cur = 0
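For intuition, here is a standalone toy sketch (not part of TFFRCNN) of what the shuffle state looks like for a roidb of five images:

import numpy as np

roidb = [{'boxes': None} for _ in range(5)]            # stand-in for five image entries
perm = np.random.permutation(np.arange(len(roidb)))    # e.g. array([3, 0, 4, 1, 2])
cur = 0                                                # cursor to the next unused slot in perm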

2. _get_next_minibatch_inds(self)

Returns the roidb indices for the next minibatch (cfg.TRAIN.IMS_PER_BATCH of them) and advances self._cur; called by _get_next_minibatch(...).

# get the indices of the next minibatch (cfg.TRAIN.IMS_PER_BATCH of them) and advance self._cur
def _get_next_minibatch_inds(self):
    """Return the roidb indices for the next minibatch."""
    # reshuffle once the remaining images cannot fill a whole batch
    if self._cur + cfg.TRAIN.IMS_PER_BATCH >= len(self._roidb):
        self._shuffle_roidb_inds()

    db_inds = self._perm[self._cur:self._cur + cfg.TRAIN.IMS_PER_BATCH]
    self._cur += cfg.TRAIN.IMS_PER_BATCH

    """
    # sample images with gt objects
    db_inds = np.zeros((cfg.TRAIN.IMS_PER_BATCH), dtype=np.int32)
    i = 0
    while (i < cfg.TRAIN.IMS_PER_BATCH):
        ind = self._perm[self._cur]
        num_objs = self._roidb[ind]['boxes'].shape[0]
        if num_objs != 0:
            db_inds[i] = ind
            i += 1
        self._cur += 1

        if self._cur >= len(self._roidb):
            self._shuffle_roidb_inds()
    """

    return db_inds
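The net effect is an endless, epoch-wise cycle over the shuffled roidb. The following standalone sketch mimics that control flow with toy values (a roidb of 5 images and a hypothetical IMS_PER_BATCH of 2) instead of the real cfg and class state:

import numpy as np

roidb_len = 5        # toy roidb size
ims_per_batch = 2    # stands in for cfg.TRAIN.IMS_PER_BATCH

perm = np.random.permutation(np.arange(roidb_len))
cur = 0

def next_minibatch_inds():
    """Same control flow as GtDataLayer._get_next_minibatch_inds."""
    global perm, cur
    # reshuffle once the remaining images cannot fill a whole batch
    if cur + ims_per_batch >= roidb_len:
        perm = np.random.permutation(np.arange(roidb_len))
        cur = 0
    inds = perm[cur:cur + ims_per_batch]
    cur += ims_per_batch
    return inds

for _ in range(4):
    print(next_minibatch_inds())   # two indices per call, reshuffling between epochs

Note that the >= test reshuffles even when exactly one full batch is left, so the last few images of an epoch may be skipped; the commented-out variant in the source instead skips images that contain no ground-truth boxes.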

3. _get_next_minibatch(self)

Fetches the roidb entries of the next minibatch and passes them to get_minibatch(...) in minibatch.py, which updates each roidb[i]['info_boxes'] field, adds 'data' and 'parameters' fields, assembles everything into blobs and returns them.

def _get_next_minibatch(self):
    """Return the blobs to be used for the next minibatch."""
    # _get_next_minibatch_inds returns the roidb indices of the next minibatch
    db_inds = self._get_next_minibatch_inds()
    minibatch_db = [self._roidb[i] for i in db_inds]
    # get_minibatch(...) in minibatch.py updates roidb[i]['info_boxes'], adds
    # 'data' and 'parameters' fields, assembles them into blobs and returns them
    return get_minibatch(minibatch_db, self._num_classes)
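Based on the top shapes declared in setup(...) below, the dict returned by get_minibatch(...) can be pictured roughly as follows. The shapes here are hand-picked placeholders for illustration only; the real ones are produced by minibatch.py and the config:

import numpy as np

# Placeholder shapes only -- illustrative, not what minibatch.py actually emits.
blobs = {
    'data':       np.zeros((1, 3, 600, 800), dtype=np.float32),   # batch of images, 3 channels
    'info_boxes': np.zeros((64, 18), dtype=np.float32),           # one 18-dim row per box
    'parameters': np.zeros((14,), dtype=np.float32),              # 2 + 2*num_scale + 2*num_aspect
}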

4. set_roidb(self, roidb)

Initializes self._roidb and, through _shuffle_roidb_inds(), produces self._perm and self._cur.

# this function is called in training the net
def set_roidb(self, roidb):
    """Set the roidb to be used by this layer during training."""
    self._roidb = roidb
    self._shuffle_roidb_inds()

5. setup(self, bottom, top)

Parses the layer parameters and reshapes the GtDataLayer's top blobs.

def setup(self, bottom, top):
    """Setup the GtDataLayer."""
    # parse the layer parameter string, which must be valid YAML
    layer_params = yaml.load(self.param_str_)
    self._num_classes = layer_params['num_classes']
    self._name_to_top_map = {
        'data': 0,
        'info_boxes': 1,
        'parameters': 2}

    # data blob: holds a batch of N images, each with 3 channels
    # The height and width (100 x 100) are dummy values
    # default TRAIN.SCALES_BASE = (0.25, 0.5, 1.0, 2.0, 3.0)
    num_scale_base = len(cfg.TRAIN.SCALES_BASE)
    # purpose unclear; tied to the Caffe modules
    top[0].reshape(num_scale_base, 3, 100, 100)

    # info boxes blob
    top[1].reshape(1, 18)

    # parameters blob
    num_scale = len(cfg.TRAIN.SCALES)
    num_aspect = len(cfg.TRAIN.ASPECTS)
    top[2].reshape(2 + 2*num_scale + 2*num_aspect)
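As a concrete size check, using the default TRAIN.SCALES_BASE quoted in the comment above and hypothetical TRAIN.SCALES / TRAIN.ASPECTS values (the real lists live in the project config, not in this sketch):

SCALES_BASE = (0.25, 0.5, 1.0, 2.0, 3.0)    # default quoted in the comment above
SCALES = (1, 2, 3)                           # assumed scale indices, for illustration only
ASPECTS = (0.5, 1.0, 2.0)                    # assumed aspect ratios, for illustration only

num_scale_base = len(SCALES_BASE)            # 5 -> first dim of the dummy 'data' top
num_scale = len(SCALES)                      # 3
num_aspect = len(ASPECTS)                    # 3
print(2 + 2 * num_scale + 2 * num_aspect)    # 'parameters' top length: 14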

6. forward(self, bottom, top)

Fetches the next minibatch and copies it into the top blobs.

def forward(self, bottom, top):
    """Get blobs and copy them into this layer's top blob vector."""
    blobs = self._get_next_minibatch()

    for blob_name, blob in blobs.iteritems():
        # this map was defined in setup(...):
        # self._name_to_top_map = {
        #     'data': 0,
        #     'info_boxes': 1,
        #     'parameters': 2}
        top_ind = self._name_to_top_map[blob_name]
        # Reshape net's input blobs
        top[top_ind].reshape(*(blob.shape))
        # Copy data into net's input blobs
        top[top_ind].data[...] = blob.astype(np.float32, copy=False)
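The reshape-then-copy pattern is how Caffe Python layers feed data into the net's input blobs. A minimal NumPy stand-in (a sketch only; a real Caffe blob also carries diffs and GPU memory) makes the data flow explicit:

import numpy as np

class FakeBlob(object):
    """Minimal stand-in for a Caffe top blob: a resizable float32 array."""
    def __init__(self):
        self.data = np.zeros((1,), dtype=np.float32)

    def reshape(self, *shape):
        self.data = np.zeros(shape, dtype=np.float32)

top = [FakeBlob(), FakeBlob(), FakeBlob()]
name_to_top_map = {'data': 0, 'info_boxes': 1, 'parameters': 2}

# fake minibatch blobs, shaped like the placeholders sketched earlier
blobs = {'data': np.ones((1, 3, 600, 800)),
         'info_boxes': np.ones((64, 18)),
         'parameters': np.ones((14,))}

for blob_name, blob in blobs.items():       # .iteritems() in the Python 2 original
    top_ind = name_to_top_map[blob_name]
    top[top_ind].reshape(*blob.shape)
    top[top_ind].data[...] = blob.astype(np.float32, copy=False)

print(top[name_to_top_map['info_boxes']].data.shape)   # (64, 18)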

7. backward(self, top, propagate_down, bottom)

def backward(self, top, propagate_down, bottom):
    """This layer does not propagate gradients."""
    pass

8. reshape(self, bottom, top)

def reshape(self, bottom, top):
    """Reshaping happens during the call to forward."""
    pass

Reposted from: https://www.cnblogs.com/deeplearning1314/p/11325011.html
