YoloV7 标签匹配及 loss 计算解析


🎍本篇文章主要对 YoloV7 的后处理进行详细讲解,YoloV7 除了结构上,对前后处理都进行了改进,其余包括 scheduler、optimizer 等与 YoloV6 都是保持一致的。而前处理中的多数 trick 也可以由其他,例如 X 中的数据增强方式替代。因此我们着重介绍后处理部分。

如上如所示,YoloV7 同大多数单阶段目标检测算法属于密集检测 (dense detection)。上图是一个 7x7 的特征图红色的点是基于特征图的网格点,进行偏移后的点,然后在其上铺设 anchor box,每个点铺设一定数量的 anchor。当然也有直接在网格点上进行铺设的,一般来讲没有太大差别。下面我们开始介绍 v7 后处理,主要分为两部分:标签匹配和 loss 计算。

Label Assignment

📄标签匹配主要分为两步:先是进行粗筛,然后是进行精筛

Find-3-Positive

📑顾名思义,第一步是找到三个正样本,就是对于每一个 GT 找到上图的三个 anchor 作为正样本。首先我们先大概讲一下匹配的规则。

如上图所示,对于每一个网格,会被分为四个部分,绿色点是 GT 中心点,蓝色点是匹配给 GT 的正样本点。首先 GT 中心点所在的网格会被定义为正样本,然后根据中心点在网格的位置来找到另外两个正样本。比如在位置 1 是左上的点会被定义为其正样本,位置 2 是右上,位置 3 是左下,位置 4 是右下。

📚下面是代码的注解和讲解:

# na -- num_anchor nt -- num_targets
# targets [nt, 6] [bs_id, cls, x, y, w, h]
na, nt = self.na, targets.shape[0]
# indices for assignment idx
# anch for assignment anchor
indices, anch = [], []
# 用于将 target 映射到特征图的网格上
gain = torch.ones(7, device=targets.device).long()
# [na, nt] eg. ai[0, :] = [0, ..., 0] ai[1, :] = [1, ..., 1]
ai = torch.arange(na, device=targets.device).float().view(na, 1).repeat(1, nt)
# [na, nt, 7]
targets = torch.cat((targets.repeat(na, 1, 1), ai[:, :, None]), 2)
# 预设 4 个方向上的偏移
g = 0.5
off = torch.tensor([[0, 0],
                    [1, 0], [0, 1], [-1, 0], [0, -1],
                   ], device=targets.device).float() * g
    # 对每个特征层进行匹配
	for i in range(self.nl):
        # 获取特征图 i 的 anchors [3, 2]
        anchors = self.anchors[i]
        # gain = [1, 1, w_i, h_i, w_i, h_i, 1]
        gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]]
        t = targets * gain
        if nt:
            # 获取每个 gt 与 anchor 的长之间的比与宽之间的比 [na, nt, 2]
            r = t[:, :, 4:6] / anchors[:, None]
            # 获得比值都小于阈值 anchor_t 的 mask [na, nt]
            j = torch.max(r, 1. / r).max(2)[0] < self.hyp['anchor_t']
            # 获得满足条件的 gt [n1, 7]
            t = t[j]
            # gt 在该特征图上的坐标
            gxy = t[:, 2:4]
            # 将坐标映射到网格的对称位置,网格的 1 和 4 对称,2 和 3 对称
            gxi = gain[[2, 3]] - gxy
            # 如果在网格位置 1:j、k 均为真值,即找到了左上两个正样本
            # 位置 2:j 为假,k 为真(上) 但此时位置 2 无法缺少一个正样本没有表示
            # [n1] [n1]
            j, k = ((gxy % 1. < g) & (gxy > 1.)).T
            # 由上行就可以看出无法表示 2 的全部正样本,对于位置 2 只能找到一个正样本
            # 3,4 位置正样本无法表示,因此要将 2,3,4 映射到对称位置上
            # 其中位置 1:j -> 1、k -> 1、l -> 0、m -> 0
            # 其中位置 2:j -> 0、k -> 1、l -> 1、m -> 0
            # 其中位置 3:j -> 1、k -> 0、l -> 0、m -> 1
            # 其中位置 4:j -> 0、k -> 0、l -> 1、m -> 1
            l, m = ((gxi % 1. < g) & (gxi > 1.)).T
            # [n1 ,5] 是否选取 gt 中心点所在网格上下左右四个网格为正样本
            # jklm 分别代表左、上、右、下
            j = torch.stack((torch.ones_like(j), j, k, l, m))
            # [5, n1, 7] -> [n2, 7]
            t = t.repeat((5, 1, 1))[j]
            # [5, n1, 2] -> [n2, 2] 获得偏移,以便后面得到三个正样本的坐标
            # 所以这里的 n2 = 3 * n1
            offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j]
        else:
            t = targets[0]
            offsets = 0
        # 正样本的 bs_idx、cls、横纵坐标、宽高
        b, c = t[:, :2].long().T
        gxy = t[:, 2:4]
        gwh = t[:, 4:6]
        # 获得正样本 xy 坐标
        gij = (gxy - offsets).long()
        gi, gj = gij.T
        # 所有正样本匹配的 anchor idx
        a = t[:, 6].long()
        indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1)))
        anch.append(anchors[a])
    
    return indices, anch

OTA

📑在获得初步的正样本后,然后进行第二次筛选,使用的是简易的 OTA,其计算过程是:先计算每个 gt 与所有 anchor 的 iou 并将前十的 iou 之和作为该 gt 匹配的正样本个数,然后计算每一个 gt 和 anchor 计算 cost 来确定最后的正样本,需要注意的是选正样本是基于第一步的基础之上。

def build_targets(self, p, targets, imgs):
    indices, anch = self.find_3_positive(p, targets)

    # [[], [], []]
    matching_bs = [[] for pp in p]
    matching_as = [[] for pp in p]
    matching_gjs = [[] for pp in p]
    matching_gis = [[] for pp in p]
    matching_targets = [[] for pp in p]
    matching_anchs = [[] for pp in p]

    nl = len(p)    

    for batch_idx in range(p[0].shape[0]):
		# 取出第 batch_idx 的 target
        b_idx = targets[:, 0]==batch_idx
        this_target = targets[b_idx]
        if this_target.shape[0] == 0:
            continue
            # 将 target 完全映射回原长度
        txywh = this_target[:, 2:6] * imgs[batch_idx].shape[1]
        txyxy = xywh2xyxy(txywh)

        pxyxys = []
        p_cls = []
        p_obj = []
        from_which_layer = []
        all_b = []
        all_a = []
        all_gj = []
        all_gi = []
        all_anch = []

        for i, pi in enumerate(p):
            # 取出第 i 个 level 的匹配的信息
            b, a, gj, gi = indices[i]
            # 取出该图在 level 匹配到的样本信息 mask
            idx = (b == batch_idx)
            b, a, gj, gi = b[idx], a[idx], gj[idx], gi[idx]                
            all_b.append(b)
            all_a.append(a)
            all_gj.append(gj)
            all_gi.append(gi)
            all_anch.append(anch[i][idx])
            from_which_layer.append(torch.ones(size=(len(b),)) * i)
            # 获得所有正样本的预测 [n, 85]
            fg_pred = pi[b, a, gj, gi]                
            p_obj.append(fg_pred[:, 4:5])
            p_cls.append(fg_pred[:, 5:])

            grid = torch.stack([gi, gj], dim=1)
            pxy = (fg_pred[:, :2].sigmoid() * 2. - 0.5 + grid) * self.stride[i] 
            pwh = (fg_pred[:, 2:4].sigmoid() * 2) ** 2 * anch[i][idx] * self.stride[i]
            pxywh = torch.cat([pxy, pwh], dim=-1)
            pxyxy = xywh2xyxy(pxywh)
            pxyxys.append(pxyxy)
            # [n ,4]
        pxyxys = torch.cat(pxyxys, dim=0)
        if pxyxys.shape[0] == 0:
            continue
        p_obj = torch.cat(p_obj, dim=0)
        p_cls = torch.cat(p_cls, dim=0)
        from_which_layer = torch.cat(from_which_layer, dim=0)
        all_b = torch.cat(all_b, dim=0)
        all_a = torch.cat(all_a, dim=0)
        all_gj = torch.cat(all_gj, dim=0)
        all_gi = torch.cat(all_gi, dim=0)
        all_anch = torch.cat(all_anch, dim=0)
        # [n, n] 获取每一个 tgt 与预测框的 iou
        pair_wise_iou = box_iou(txyxy, pxyxys)
        pair_wise_iou_loss = -torch.log(pair_wise_iou + 1e-8)
        # [n, 10] 取与 tgt iou 排名前 10 的预测框
        top_k, _ = torch.topk(pair_wise_iou, min(10, pair_wise_iou.shape[1]), dim=1)
        # [n] 对于每个 tgt 的正样本个数取为上述前 10 个 iou 的和
        dynamic_ks = torch.clamp(top_k.sum(1).int(), min=1)
        # [n, n, 80]
        gt_cls_per_image = (
            F.one_hot(this_target[:, 1].to(torch.int64), self.nc)
            .float()
            .unsqueeze(1)
            .repeat(1, pxyxys.shape[0], 1)
        )

        num_gt = this_target.shape[0]
        # [n, n, 80]
        cls_preds_ = (
            p_cls.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
            * p_obj.unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
        )

        y = cls_preds_.sqrt_()
        # [n, n] 计算 cls 的代价
        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
            torch.log(y/(1-y)) , gt_cls_per_image, reduction="none"
        ).sum(-1)
        del cls_preds_

        cost = (
            pair_wise_cls_loss
            + 3.0 * pair_wise_iou_loss
        )
        # [n, n]
        matching_matrix = torch.zeros_like(cost)

        for gt_idx in range(num_gt):
            # 对于每个 gt 选择 cost 前 dynamic_ks 个小的预测框作为正样本
            _, pos_idx = torch.topk(
                cost[gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
            )
            matching_matrix[gt_idx][pos_idx] = 1.0

        del top_k, dynamic_ks
        # [n] 计算每个 anchor 匹配到了多少个 gt
        anchor_matching_gt = matching_matrix.sum(0)
        if (anchor_matching_gt > 1).sum() > 0:
            # [n, am] 获取匹配 gt 数大于 1 的 anchor 与 gt 最小的 cost 的 gt 的下标
            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
            # 先全部置为 0
            matching_matrix[:, anchor_matching_gt > 1] *= 0.0
            # 将与 anchor 中 cost 最小对应的 gt 设置为它的正样本
            # 这样做的原因是因为保证每一个 anchor 只匹配到一个 gt
            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
        # [n] 获得 anchor 的 mask,其中 true 为正样本,反之为负样本
        fg_mask_inboxes = matching_matrix.sum(0) > 0.0
        # 获取每个 anchor 匹配到 gt 的下标
        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
        # 获取所有正样本的信息
        from_which_layer = from_which_layer[fg_mask_inboxes]
        all_b = all_b[fg_mask_inboxes]
        all_a = all_a[fg_mask_inboxes]
        all_gj = all_gj[fg_mask_inboxes]
        all_gi = all_gi[fg_mask_inboxes]
        all_anch = all_anch[fg_mask_inboxes]

        this_target = this_target[matched_gt_inds]
        # 进行分层储存,得到不同 level 的正样本
        for i in range(nl):
            layer_idx = from_which_layer == i
            matching_bs[i].append(all_b[layer_idx])
            matching_as[i].append(all_a[layer_idx])
            matching_gjs[i].append(all_gj[layer_idx])
            matching_gis[i].append(all_gi[layer_idx])
            matching_targets[i].append(this_target[layer_idx])
            matching_anchs[i].append(all_anch[layer_idx])

    for i in range(nl):
        if matching_targets[i] != []:
        	matching_bs[i] = torch.cat(matching_bs[i], dim=0)
        	matching_as[i] = torch.cat(matching_as[i], dim=0)
        	matching_gjs[i] = torch.cat(matching_gjs[i], dim=0)
        	matching_gis[i] = torch.cat(matching_gis[i], dim=0)
        	matching_targets[i] = torch.cat(matching_targets[i], dim=0)
        	matching_anchs[i] = torch.cat(matching_anchs[i], dim=0)
        else:
            matching_bs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
            matching_as[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
            matching_gjs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
            matching_gis[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
            matching_targets[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)
            matching_anchs[i] = torch.tensor([], device='cuda:0', dtype=torch.int64)

    return matching_bs, matching_as, matching_gjs, matching_gis, matching_targets, matching_anch

Loss Computation

📑v7 的 loss 计算分为:分类损失、回归损失和目标损失

def __call__(self, p, targets, imgs):
    device = targets.device
    lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
    bs, as_, gjs, gis, targets, anchors = self.build_targets(p, targets, imgs)
    # 每个 level 特征图的宽高
    pre_gen_gains = [torch.tensor(pp.shape, device=device)[[3, 2, 3, 2]] for pp in p] 


    for i, pi in enumerate(p):
        # 取出第 level 的正样本信息
        b, a, gj, gi = bs[i], as_[i], gjs[i], gis[i]
        tobj = torch.zeros_like(pi[..., 0], device=device)

        n = b.shape[0]
        if n:
            # 获得正样本的预测结果
            ps = pi[b, a, gj, gi]

            # 正样本的坐标
            grid = torch.stack([gi, gj], dim=1)
            # 这里得到的是用 stride norm 过后的预测框
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1)

            selected_tbox = targets[i][:, 2:6] * pre_gen_gains[i]
            # 预测的是离格点距离,因此 gt 需要减去相应格点的坐标
            selected_tbox[:, :2] -= grid
            iou = bbox_iou(pbox.T, selected_tbox, x1y1x2y2=False, CIoU=True)
            lbox += (1.0 - iou).mean()

            # 将 obj label 中正样本位置设置为与 gt 的 iou,其余为 0
            tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype)

            selected_tcls = targets[i][:, 1].long()
            if self.nc > 1:
                # cls loss 只计算了正样本的 loss
                t = torch.full_like(ps[:, 5:], self.cn, device=device)
                t[range(n), selected_tcls] = self.cp
                lcls += self.BCEcls(ps[:, 5:], t)

            # obj 计算为全部样本的 loss
            obji = self.BCEobj(pi[..., 4], tobj)
            lobj += obji * self.balance[i]  # obj loss
            if self.autobalance:
                self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()

    if self.autobalance:
        self.balance = [x / self.balance[self.ssi] for x in self.balance]
    lbox *= self.hyp['box']
    lobj *= self.hyp['obj']
    lcls *= self.hyp['cls']
    bs = tobj.shape[0]  # batch size

    loss = lbox + lobj + lcls
    return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()

☕Yolov7 其他部分如 scheduler 和 optimizer 与 Yolov6 保持一致,另外值得一提的是 v7 的隐性层可以提点 0.5 左右,但是也许仅限 v7 同时对隐性层的 lr 等需要特殊设置。


文章作者: L77
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 L77 !
评论
评论
  目录