【3D detection】CT3D部分代码的理解

paper: Improving 3D Object Detection with Channel-wise Transformer
code:https://github.com/hlsheng1/CT3D

获得box的八个角点坐标

/pcdet/models/roi_heads/ct3d_head.py130行左右

        # corner
        #输入的rois大小为(batch_size,roi的个数,roi信息)(2,128,7)
        #其中7包括roi的x,y,z,l,w,h,方向角
        #最后得到的corner_points是(B, 128, 2*2*2, 3),每个角点的坐标
        corner_points, _ = self.get_global_grid_points_of_roi(rois)  # (BxN, 2x2x2, 3)
        corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1])  # (2, 128, 2x2x2, 3)

下面看一下 self.get_global_grid_points_of_roi(rois)这个函数,同样在class CT3DHead这个类中

  def get_global_grid_points_of_roi(self, rois):
        rois = rois.view(-1, rois.shape[-1]) # (256,7)
        batch_size_rcnn = rois.shape[0]      # (256)
        # 得到了box八个角点到中心点的坐标距离
        # 中心点的三个轴坐标只要在每个轴加上相应的坐标距离
        # 就能得到每一个角点的坐标了。
        local_roi_grid_points = self.get_corner_points(rois, batch_size_rcnn) #(256,8,3)
        
        #注意!!!
        #上面得到的只考虑了box的大小,并没有考虑角度
        #如果直接拿来计算的话,得到的所有角点都不是真正的,而是平行于坐标轴的!
        #所以还要通过下面的函数将每个roi的角度考虑信息。
        #最后得到真正的角点全局坐标
        global_roi_grid_points = common_utils.rotate_points_along_z(
            local_roi_grid_points.clone(), rois[:, 6]
        ).squeeze(dim=1) #(256, 8 ,3)
        
        #得到中心点的全局坐标
        global_center = rois[:, 0:3].clone() # (256, 3)
        # pdb.set_trace()
		#中心点加坐标距离最终得到每个box八个点的角点
        global_roi_grid_points += global_center.unsqueeze(dim=1) #(256,8,3  )
        
        #(256,8,3)  (256,8,3)
        return global_roi_grid_points, local_roi_grid_points

看一下 self.get_corner_points(rois, batch_size_rcnn)这个函数,同样在class CT3DHead这个类中

 @staticmethod
    def get_corner_points(rois, batch_size_rcnn):
        #得到一个(2, 2, 2) 的tensor,用来在后面求box八个角点的索引 ,这里还是很巧妙的,可以学习一下。
        faked_features = rois.new_ones((2, 2, 2))
        # nonzero() 返回每个这个tensor中非零元素的索引
        # 例如这里建立了一个(2,2,2)的全是1的tensor,每一个元素都是非零的,
        #所以得到的非零元素的索引为 [0,0,0], [0,0,1], [0,1,0], [0,1,1] .....[1,1,1] 共8个
        dense_idx = faked_features.nonzero()  # (8, 3) [x_idx, y_idx, z_idx]
        dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float()  # (256, 2x2x2, 3)
		# 取每一个RoI的长宽高
        local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] # (256,3)
        #这一步是关键,求每个角点在每个轴上相对于roi中心点的距离
        #例如:[0, 0, 1] * [l, w, h] - [l/2,w/2, h/2]
        # = [-l/2, -w/2, h/2]
        #即这个角点到中心点的坐标距离
        #中心点坐标 在x轴减去一半长,在y轴减去一半宽,在z轴加上一般高。就是这个角点的坐标了。
        roi_grid_points = dense_idx * local_roi_size.unsqueeze(dim=1) \
                          - (local_roi_size.unsqueeze(dim=1) / 2)  # (B, 2x2x2, 3)
        return roi_grid_points #(256,8,3)

这个函数得到了box每一个角点到中心点的坐标距离,中心点的三个轴坐标只要在每个轴加上相应的坐标距离,就能得到每一个角点的坐标了。

在无限高的圆柱中随机采样

论文中在RoI中采样点时,将RoI转换为一个高度无限高的圆柱体,并从中随机采样256个点作为RoI的表征。采样方法如下

圆柱体底面的半径为:
圆柱底面半径的公式
代码如下

        num_sample = self.num_points  # 这里是256个点
        src = rois.new_zeros(batch_size, num_rois, num_sample, 4) #(2,128,256,4)

        for bs_idx in range(batch_size): # 每次循环一个batch_size
        	# batch_dict[points]是所有batch_size中所有的点
        	# batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)]是第bs_idx个batch中的点
        	# batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)][:,1:5]是每个点的坐标和反射强度
        	#cur_points是一个batch中所有点的坐标+反射强度(194165,4)
            cur_points = batch_dict['points'][(batch_dict['points'][:, 0] == bs_idx)][:,1:5] #(194165,4)
			# 每个batch的roi box 
            cur_batch_boxes = batch_dict['rois'][bs_idx] #(128,7)
            
            # 求半径公式如上图所示 (128)
            cur_radiis = torch.sqrt((cur_batch_boxes[:,3]/2) ** 2 + (cur_batch_boxes[:,4]/2) ** 2) * 1.2
            
            # 所有点到roi box中心的距离,
            # 共128个RoI,每一个RoI 都计算19165个点到其中心的距离 
            dis = torch.norm((cur_points[:,:2].unsqueeze(0) - cur_batch_boxes[:,:2].unsqueeze(1).repeat(1,cur_points.shape[0],1)), dim = 2) # (128,19165)
            #过滤出半径内的点
            point_mask = (dis <= cur_radiis.unsqueeze(-1))
            # 遍历每一个roi
            for roi_box_idx in range(0, num_rois):
            	# point_mask[roi_box_idx] 是第roi_box_idx个roi的mask
      			#cur_roi_points 这里是(465,4),即这个圆柱roi中有465个点
                cur_roi_points = cur_points[point_mask[roi_box_idx]]
                
				# 如果roi内部的点大于要求采样的数量的话,就随机取256个(源码中num_sample=256)
				# 如果roi内部的点个数等于0,就用0填充256个
				# 如果roi内部点在0到256之间,采样所有点,剩余的用0填充
                if cur_roi_points.shape[0] >= num_sample:
                    random.seed(0)
                    index = np.random.randint(cur_roi_points.shape[0], size=num_sample)
                    cur_roi_points_sample = cur_roi_points[index]

                elif cur_roi_points.shape[0] == 0:
                    cur_roi_points_sample = cur_roi_points.new_zeros(num_sample, 4)

                else:
                    empty_num = num_sample - cur_roi_points.shape[0]
                    add_zeros = cur_roi_points.new_zeros(empty_num, 4)
                    add_zeros = cur_roi_points[0].repeat(empty_num, 1)
                    cur_roi_points_sample = torch.cat([cur_roi_points, add_zeros], dim = 0)
				#这个roi的采样结束,记录到src中
                src[bs_idx, roi_box_idx, :, :] = cur_roi_points_sample


		#采样结束,经过整理得到src,共b个batch,每个batch中128个roi,每个roi取256个点(x,y,z,r)
        src = src.view(batch_size * num_rois, -1, src.shape[-1])  # (b*128, 256, 4)

Embedding和Encoding

论文的主要结构图如下所示:
论文结构图
得到了圆柱体内部的采样点,和box角点,中心点的信息。可以进行embedding了
如下所示
Embedding和Encoding
以一个roi举例,roi内部256个采样点,每一个点都计算自己与roi八个角点以及中心点的距离,再包括自己的反射强度,通过线形层升维。具体公式如下所示:对于点采样点pi, fi是对该点的embedding

在这里插入图片描述

Enbedding的代码如下:

 # src是采样得到的roi内部的点
 src = src.view(batch_size * num_rois, -1, src.shape[-1])  # (b*128, 256, 4)

		#corner_points是上一步得到的roi八个角点坐标(256,8,3) -->(256,24)
        corner_points = corner_points.view(batch_size * num_rois, -1)
        #将每个roi的中心点坐标concat到八个角点坐标上
        # corner_add_center_points (256,24) -->(256,27)
        corner_add_center_points = torch.cat([corner_points, rois.view(-1, rois.shape[-1])[:,:3]], dim = -1)
        # 计算每个点与角点中心点得坐标距离
        #pos_fea(b*roi, num_sample, 27)
        pos_fea = src[:,:,:3].repeat(1,1,9) - corner_add_center_points.unsqueeze(1).repeat(1,num_sample,1)  # 27 维度
        #roi的长宽高 lwh (b*roi, num_sample, 3)
        lwh = rois.view(-1, rois.shape[-1])[:,3:6].unsqueeze(1).repeat(1,num_sample,1)
        # (l*l + w*w + h*h) ** 0.5 
        diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5

		# pos_fea(256,256,27) 每一个点到角点,中心点的 球形坐标距离
        pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))
        # src(256,256,28)
        src = torch.cat([pos_fea, src[:,:,-1].unsqueeze(-1)], dim = -1)
		#src(256,256,256)
        src = self.up_dimension(src)

下面看一下 pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))这个函数:

  def spherical_coordinate(self, src, diag_dist):
  		# src(256,256,27)每个采样点到角点中心点的坐标距离
  		# diag_dist(256,256,1)
  		
        assert (src.shape[-1] == 27)
        device = src.device
        #分别取每个采样点到角点中心点坐标距离
        indices_x = torch.LongTensor([0,3,6,9,12,15,18,21,24]).to(device)  #
        indices_y = torch.LongTensor([1,4,7,10,13,16,19,22,25]).to(device) # 
        indices_z = torch.LongTensor([2,5,8,11,14,17,20,23,26]).to(device) 
        src_x = torch.index_select(src, -1, indices_x)  # (256, 256, 9)
        src_y = torch.index_select(src, -1, indices_y)  # (256, 256, 9)
        src_z = torch.index_select(src, -1, indices_z)  # (256, 256, 9)

		#把坐标距离转换成球形坐标距离??
        dis = (src_x ** 2 + src_y ** 2 + src_z ** 2) ** 0.5
        phi = torch.atan(src_y / (src_x + 1e-5))
        the = torch.acos(src_z / (dis + 1e-5))
        dis = dis / diag_dist
        #src(256,256,256)球形坐标距离??
        src = torch.cat([dis, phi, the], dim = -1)
        return src

这里为什么要把坐标系转换成球形坐标,作者在github上解释说,换不换没有什么差别,所以在论文中就没有提到…

Tranformer如下所示


    def forward(self, src, query_embed, pos_embed):
		# src (256,256,256)
		# query_embed (1,256)
		# pos_embed (256,256,256) 位置编码在Emcoder中q,k相加
        bs, n, c = src.shape
        src = src.permute(1, 0, 2)
        pos_embed = pos_embed.permute(1, 0, 2)
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (1,256,256)
        
        tgt = torch.zeros_like(query_embed) #(1,256,256)
        # memory (256,256,256)
        memory = self.encoder(src, src_key_padding_mask=None, pos=pos_embed)
         #因为在整体代码中,TransformerDecoder的初始化参数return_intermediate设置为True
        #因此,Decoder的输出包含了每层的结果,共有一层,shape是[1,num_querie,batch_size,hidden_dim]
        hs = self.decoder(tgt, memory, memory_key_padding_mask=None,
                          pos=pos_embed, query_pos=query_embed)
        #(1,256,1,256) 
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, n)

Encoder模块没有什么可说的,和DETR一样的,共有三层。encoder得到memory(256,256,256),256个roi,每个roi256个采样点,通道数为256。与输入相同
Decoder共一层,与DETR一样,输入的num_query = 1 ,代表每个roi里256个采样点生成一个box 。
k,v来自encoder,每一个roi得到一个box。
唯一不同的是,decoder在计算self-attention时候。公式由标准的
在这里插入图片描述

变成了论文中的:
在这里插入图片描述
再乘以V

代码如下

def attention(query, key,  value):
    dim = query.shape[1]
    scores_1 = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
    scores_2 = torch.einsum('abcd, aced->abcd', key, scores_1)
    prob = torch.nn.functional.softmax(scores_2, dim=-1)
    output = torch.einsum('bnhm,bdhm->bdhn', prob, value)
    return output, prob

版权声明:本文为CSDN博主「printfff」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/printfff/article/details/122234330

printfff

我还没有学会写个人说明!

暂无评论

发表评论

相关推荐

tf-faster-rcnn运行再记录

拿到新的数据,修改roi_rename.py里文件的路径、名称等; 新建resource-Annotations文件夹; 然后打开anaconda、activate Labelimg、labelimg&