文章目录[隐藏]
paper: Improving 3D Object Detection with Channel-wise Transformer
code:https://github.com/hlsheng1/CT3D
获得box的八个角点坐标
在/pcdet/models/roi_heads/ct3d_head.py
中 130行左右
# corner
#输入的rois大小为(batch_size,roi的个数,roi信息)(2,128,7)
#其中7包括roi的x,y,z,l,w,h,方向角
#最后得到的corner_points是(B, 128, 2*2*2, 3),每个角点的坐标
corner_points, _ = self.get_global_grid_points_of_roi(rois) # (BxN, 2x2x2, 3)
corner_points = corner_points.view(batch_size, num_rois, -1, corner_points.shape[-1]) # (2, 128, 2x2x2, 3)
下面看一下 self.get_global_grid_points_of_roi(rois)
这个函数,同样在class CT3DHead
这个类中
def get_global_grid_points_of_roi(self, rois):
rois = rois.view(-1, rois.shape[-1]) # (256,7)
batch_size_rcnn = rois.shape[0] # (256)
# 得到了box八个角点到中心点的坐标距离
# 中心点的三个轴坐标只要在每个轴加上相应的坐标距离
# 就能得到每一个角点的坐标了。
local_roi_grid_points = self.get_corner_points(rois, batch_size_rcnn) #(256,8,3)
#注意!!!
#上面得到的只考虑了box的大小,并没有考虑角度
#如果直接拿来计算的话,得到的所有角点都不是真正的,而是平行于坐标轴的!
#所以还要通过下面的函数将每个roi的角度考虑信息。
#最后得到真正的角点全局坐标
global_roi_grid_points = common_utils.rotate_points_along_z(
local_roi_grid_points.clone(), rois[:, 6]
).squeeze(dim=1) #(256, 8 ,3)
#得到中心点的全局坐标
global_center = rois[:, 0:3].clone() # (256, 3)
# pdb.set_trace()
#中心点加坐标距离最终得到每个box八个点的角点
global_roi_grid_points += global_center.unsqueeze(dim=1) #(256,8,3 )
#(256,8,3) (256,8,3)
return global_roi_grid_points, local_roi_grid_points
看一下 self.get_corner_points(rois, batch_size_rcnn)
这个函数,同样在class CT3DHead
这个类中
@staticmethod
def get_corner_points(rois, batch_size_rcnn):
#得到一个(2, 2, 2) 的tensor,用来在后面求box八个角点的索引 ,这里还是很巧妙的,可以学习一下。
faked_features = rois.new_ones((2, 2, 2))
# nonzero() 返回每个这个tensor中非零元素的索引
# 例如这里建立了一个(2,2,2)的全是1的tensor,每一个元素都是非零的,
#所以得到的非零元素的索引为 [0,0,0], [0,0,1], [0,1,0], [0,1,1] .....[1,1,1] 共8个
dense_idx = faked_features.nonzero() # (8, 3) [x_idx, y_idx, z_idx]
dense_idx = dense_idx.repeat(batch_size_rcnn, 1, 1).float() # (256, 2x2x2, 3)
# 取每一个RoI的长宽高
local_roi_size = rois.view(batch_size_rcnn, -1)[:, 3:6] # (256,3)
#这一步是关键,求每个角点在每个轴上相对于roi中心点的距离
#例如:[0, 0, 1] * [l, w, h] - [l/2,w/2, h/2]
# = [-l/2, -w/2, h/2]
#即这个角点到中心点的坐标距离
#中心点坐标 在x轴减去一半长,在y轴减去一半宽,在z轴加上一般高。就是这个角点的坐标了。
roi_grid_points = dense_idx * local_roi_size.unsqueeze(dim=1) \
- (local_roi_size.unsqueeze(dim=1) / 2) # (B, 2x2x2, 3)
return roi_grid_points #(256,8,3)
这个函数得到了box每一个角点到中心点的坐标距离,中心点的三个轴坐标只要在每个轴加上相应的坐标距离,就能得到每一个角点的坐标了。
在无限高的圆柱中随机采样
论文中在RoI中采样点时,将RoI转换为一个高度无限高的圆柱体,并从中随机采样256个点作为RoI的表征。采样方法如下
圆柱体底面的半径为:
代码如下
num_sample = self.num_points # 这里是256个点
src = rois.new_zeros(batch_size, num_rois, num_sample, 4) #(2,128,256,4)
for bs_idx in range(batch_size): # 每次循环一个batch_size
# batch_dict[points]是所有batch_size中所有的点
# batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)]是第bs_idx个batch中的点
# batch_dict[points][(batch_dict['points'][:, 0] == bs_idx)][:,1:5]是每个点的坐标和反射强度
#cur_points是一个batch中所有点的坐标+反射强度(194165,4)
cur_points = batch_dict['points'][(batch_dict['points'][:, 0] == bs_idx)][:,1:5] #(194165,4)
# 每个batch的roi box
cur_batch_boxes = batch_dict['rois'][bs_idx] #(128,7)
# 求半径公式如上图所示 (128)
cur_radiis = torch.sqrt((cur_batch_boxes[:,3]/2) ** 2 + (cur_batch_boxes[:,4]/2) ** 2) * 1.2
# 所有点到roi box中心的距离,
# 共128个RoI,每一个RoI 都计算19165个点到其中心的距离
dis = torch.norm((cur_points[:,:2].unsqueeze(0) - cur_batch_boxes[:,:2].unsqueeze(1).repeat(1,cur_points.shape[0],1)), dim = 2) # (128,19165)
#过滤出半径内的点
point_mask = (dis <= cur_radiis.unsqueeze(-1))
# 遍历每一个roi
for roi_box_idx in range(0, num_rois):
# point_mask[roi_box_idx] 是第roi_box_idx个roi的mask
#cur_roi_points 这里是(465,4),即这个圆柱roi中有465个点
cur_roi_points = cur_points[point_mask[roi_box_idx]]
# 如果roi内部的点大于要求采样的数量的话,就随机取256个(源码中num_sample=256)
# 如果roi内部的点个数等于0,就用0填充256个
# 如果roi内部点在0到256之间,采样所有点,剩余的用0填充
if cur_roi_points.shape[0] >= num_sample:
random.seed(0)
index = np.random.randint(cur_roi_points.shape[0], size=num_sample)
cur_roi_points_sample = cur_roi_points[index]
elif cur_roi_points.shape[0] == 0:
cur_roi_points_sample = cur_roi_points.new_zeros(num_sample, 4)
else:
empty_num = num_sample - cur_roi_points.shape[0]
add_zeros = cur_roi_points.new_zeros(empty_num, 4)
add_zeros = cur_roi_points[0].repeat(empty_num, 1)
cur_roi_points_sample = torch.cat([cur_roi_points, add_zeros], dim = 0)
#这个roi的采样结束,记录到src中
src[bs_idx, roi_box_idx, :, :] = cur_roi_points_sample
#采样结束,经过整理得到src,共b个batch,每个batch中128个roi,每个roi取256个点(x,y,z,r)
src = src.view(batch_size * num_rois, -1, src.shape[-1]) # (b*128, 256, 4)
Embedding和Encoding
论文的主要结构图如下所示:
得到了圆柱体内部的采样点,和box角点,中心点的信息。可以进行embedding了
如下所示
以一个roi举例,roi内部256个采样点,每一个点都计算自己与roi八个角点以及中心点的距离,再包括自己的反射强度,通过线形层升维。具体公式如下所示:对于点采样点pi, fi是对该点的embedding
Enbedding的代码如下:
# src是采样得到的roi内部的点
src = src.view(batch_size * num_rois, -1, src.shape[-1]) # (b*128, 256, 4)
#corner_points是上一步得到的roi八个角点坐标(256,8,3) -->(256,24)
corner_points = corner_points.view(batch_size * num_rois, -1)
#将每个roi的中心点坐标concat到八个角点坐标上
# corner_add_center_points (256,24) -->(256,27)
corner_add_center_points = torch.cat([corner_points, rois.view(-1, rois.shape[-1])[:,:3]], dim = -1)
# 计算每个点与角点中心点得坐标距离
#pos_fea(b*roi, num_sample, 27)
pos_fea = src[:,:,:3].repeat(1,1,9) - corner_add_center_points.unsqueeze(1).repeat(1,num_sample,1) # 27 维度
#roi的长宽高 lwh (b*roi, num_sample, 3)
lwh = rois.view(-1, rois.shape[-1])[:,3:6].unsqueeze(1).repeat(1,num_sample,1)
# (l*l + w*w + h*h) ** 0.5
diag_dist = (lwh[:,:,0]**2 + lwh[:,:,1]**2 + lwh[:,:,2]**2) ** 0.5
# pos_fea(256,256,27) 每一个点到角点,中心点的 球形坐标距离
pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))
# src(256,256,28)
src = torch.cat([pos_fea, src[:,:,-1].unsqueeze(-1)], dim = -1)
#src(256,256,256)
src = self.up_dimension(src)
下面看一下 pos_fea = self.spherical_coordinate(pos_fea, diag_dist = diag_dist.unsqueeze(-1))
这个函数:
def spherical_coordinate(self, src, diag_dist):
# src(256,256,27)每个采样点到角点中心点的坐标距离
# diag_dist(256,256,1)
assert (src.shape[-1] == 27)
device = src.device
#分别取每个采样点到角点中心点坐标距离
indices_x = torch.LongTensor([0,3,6,9,12,15,18,21,24]).to(device) #
indices_y = torch.LongTensor([1,4,7,10,13,16,19,22,25]).to(device) #
indices_z = torch.LongTensor([2,5,8,11,14,17,20,23,26]).to(device)
src_x = torch.index_select(src, -1, indices_x) # (256, 256, 9)
src_y = torch.index_select(src, -1, indices_y) # (256, 256, 9)
src_z = torch.index_select(src, -1, indices_z) # (256, 256, 9)
#把坐标距离转换成球形坐标距离??
dis = (src_x ** 2 + src_y ** 2 + src_z ** 2) ** 0.5
phi = torch.atan(src_y / (src_x + 1e-5))
the = torch.acos(src_z / (dis + 1e-5))
dis = dis / diag_dist
#src(256,256,256)球形坐标距离??
src = torch.cat([dis, phi, the], dim = -1)
return src
这里为什么要把坐标系转换成球形坐标,作者在github上解释说,换不换没有什么差别,所以在论文中就没有提到…
Tranformer如下所示
def forward(self, src, query_embed, pos_embed):
# src (256,256,256)
# query_embed (1,256)
# pos_embed (256,256,256) 位置编码在Emcoder中q,k相加
bs, n, c = src.shape
src = src.permute(1, 0, 2)
pos_embed = pos_embed.permute(1, 0, 2)
query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) # (1,256,256)
tgt = torch.zeros_like(query_embed) #(1,256,256)
# memory (256,256,256)
memory = self.encoder(src, src_key_padding_mask=None, pos=pos_embed)
#因为在整体代码中,TransformerDecoder的初始化参数return_intermediate设置为True
#因此,Decoder的输出包含了每层的结果,共有一层,shape是[1,num_querie,batch_size,hidden_dim]
hs = self.decoder(tgt, memory, memory_key_padding_mask=None,
pos=pos_embed, query_pos=query_embed)
#(1,256,1,256)
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, n)
Encoder模块没有什么可说的,和DETR一样的,共有三层。encoder得到memory(256,256,256)
,256个roi,每个roi256个采样点,通道数为256。与输入相同
Decoder共一层,与DETR一样,输入的num_query = 1 ,代表每个roi里256个采样点生成一个box 。
k,v来自encoder,每一个roi得到一个box。
唯一不同的是,decoder在计算self-attention时候。公式由标准的
变成了论文中的:
再乘以V
代码如下
def attention(query, key, value):
dim = query.shape[1]
scores_1 = torch.einsum('bdhn,bdhm->bhnm', query, key) / dim**.5
scores_2 = torch.einsum('abcd, aced->abcd', key, scores_1)
prob = torch.nn.functional.softmax(scores_2, dim=-1)
output = torch.einsum('bnhm,bdhm->bdhn', prob, value)
return output, prob
版权声明:本文为CSDN博主「printfff」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/printfff/article/details/122234330
暂无评论