PulseFocusPlatform/build/lib/ppdet/modeling/losses/jde_loss.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register

__all__ = ['JDEDetectionLoss', 'JDEEmbeddingLoss', 'JDELoss']


@register
class JDEDetectionLoss(nn.Layer):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=1):
        super(JDEDetectionLoss, self).__init__()
        self.num_classes = num_classes

    def det_loss(self, p_det, anchor, t_conf, t_box):
        pshape = paddle.shape(p_det)
        pshape.stop_gradient = True
        nB, nGh, nGw = pshape[0], pshape[-2], pshape[-1]
        nA = len(anchor)
        # reshape and transpose to [nB, nA, nGh, nGw, num_classes + 5]
        p_det = paddle.reshape(
            p_det, [nB, nA, self.num_classes + 5, nGh, nGw]).transpose(
                (0, 1, 3, 4, 2))

        # 1. loss_conf: cross_entropy
        p_conf = p_det[:, :, :, :, 4:6]
        p_conf_flatten = paddle.reshape(p_conf, [-1, 2])
        t_conf_flatten = t_conf.flatten()
        t_conf_flatten = paddle.cast(t_conf_flatten, dtype="int64")
        t_conf_flatten.stop_gradient = True
        loss_conf = F.cross_entropy(
            p_conf_flatten, t_conf_flatten, ignore_index=-1, reduction='mean')
        loss_conf.stop_gradient = False

        # 2. loss_box: smooth_l1_loss, computed on foreground cells only
        p_box = p_det[:, :, :, :, :4]
        p_box_flatten = paddle.reshape(p_box, [-1, 4])
        t_box_flatten = paddle.reshape(t_box, [-1, 4])
        fg_inds = paddle.nonzero(t_conf_flatten > 0).flatten()
        if fg_inds.numel() > 0:
            reg_delta = paddle.gather(p_box_flatten, fg_inds)
            reg_target = paddle.gather(t_box_flatten, fg_inds)
        else:
            # no foreground samples: fall back to a zero regression loss
            reg_delta = paddle.to_tensor([0, 0, 0, 0], dtype='float32')
            reg_delta.stop_gradient = False
            reg_target = paddle.to_tensor([0, 0, 0, 0], dtype='float32')
            reg_target.stop_gradient = True
        loss_box = F.smooth_l1_loss(
            reg_delta, reg_target, reduction='mean', delta=1.0)
        loss_box.stop_gradient = False

        return loss_conf, loss_box

    def forward(self, det_outs, targets, anchors):
        """
        Args:
            det_outs (list[Tensor]): outputs from the detection head, each
                one a 4-D Tensor with shape [N, C, H, W].
            targets (dict): contains 'im_id', 'gt_bbox', 'gt_ide', 'image',
                'im_shape', 'scale_factor', and the 'tbox', 'tconf', 'tide'
                targets of each FPN level.
            anchors (list[list]): anchor settings of the JDE model, an N x M
                list where N is the number of anchor levels (FPN levels) and
                M is the number of anchor scales per level.
        """
        assert len(det_outs) == len(anchors)
        loss_confs = []
        loss_boxes = []
        for i, (p_det, anchor) in enumerate(zip(det_outs, anchors)):
            t_conf = targets['tconf{}'.format(i)]
            t_box = targets['tbox{}'.format(i)]

            loss_conf, loss_box = self.det_loss(p_det, anchor, t_conf, t_box)

            loss_confs.append(loss_conf)
            loss_boxes.append(loss_box)
        return {'loss_confs': loss_confs, 'loss_boxes': loss_boxes}
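
# For orientation, a sketch of what JDEDetectionLoss.forward consumes: one head
# output per FPN level, shaped [N, A * (num_classes + 5), H, W]; after the
# reshape in det_loss, channels 0:4 carry the box deltas and channels 4:6 the
# objectness logits of each anchor. The matching targets are looked up per
# level as targets['tconf{i}'] / targets['tbox{i}'], e.g.
#
#     losses = JDEDetectionLoss(num_classes=1)(det_outs, targets, anchors)
#     # -> {'loss_confs': [...], 'loss_boxes': [...]}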


@register
class JDEEmbeddingLoss(nn.Layer):
    def __init__(self):
        super(JDEEmbeddingLoss, self).__init__()
        # Phony parameter used to build a zero loss that still participates in
        # backward when a level has no valid ID samples (see emb_loss).
        self.phony = self.create_parameter(shape=[1], dtype="float32")

    def emb_loss(self, p_ide, t_conf, t_ide, emb_scale, classifier):
        emb_dim = p_ide.shape[1]
        p_ide = p_ide.transpose((0, 2, 3, 1))
        p_ide_flatten = paddle.reshape(p_ide, [-1, emb_dim])

        mask = t_conf > 0
        mask = paddle.cast(mask, dtype="int64")
        mask.stop_gradient = True
        emb_mask = mask.max(1).flatten()
        emb_mask_inds = paddle.nonzero(emb_mask > 0).flatten()
        emb_mask_inds.stop_gradient = True
        # use max(1) to decide the id, TODO: more reasonable strategy
        t_ide_flatten = t_ide.max(1).flatten()
        t_ide_flatten = paddle.cast(t_ide_flatten, dtype="int64")
        valid_inds = paddle.nonzero(t_ide_flatten != -1).flatten()

        if emb_mask_inds.numel() == 0 or valid_inds.numel() == 0:
            # loss_ide = paddle.to_tensor([0])  # would break gradient backward
            loss_ide = self.phony * 0  # todo
        else:
            embedding = paddle.gather(p_ide_flatten, emb_mask_inds)
            embedding = emb_scale * F.normalize(embedding)
            logits = classifier(embedding)

            ide_target = paddle.gather(t_ide_flatten, emb_mask_inds)

            loss_ide = F.cross_entropy(
                logits, ide_target, ignore_index=-1, reduction='mean')
        loss_ide.stop_gradient = False

        return loss_ide

    def forward(self, ide_outs, targets, emb_scale, classifier):
        loss_ides = []
        for i, p_ide in enumerate(ide_outs):
            t_conf = targets['tconf{}'.format(i)]
            t_ide = targets['tide{}'.format(i)]

            loss_ide = self.emb_loss(p_ide, t_conf, t_ide, emb_scale,
                                     classifier)
            loss_ides.append(loss_ide)
        return loss_ides
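
# For reference, the flow inside emb_loss: embeddings at foreground cells
# (t_conf > 0, reduced over the anchor dimension with max) are gathered,
# L2-normalized, scaled by emb_scale, and fed to `classifier`, an external
# callable supplied by the model (typically a linear layer over track IDs);
# its logits are trained with cross-entropy against the gathered target IDs.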


@register
class JDELoss(nn.Layer):
    def __init__(self):
        super(JDELoss, self).__init__()

    def forward(self, loss_confs, loss_boxes, loss_ides, loss_params_cls,
                loss_params_reg, loss_params_ide, targets):
        assert len(loss_confs) == len(loss_boxes) == len(loss_ides)
        assert len(loss_params_cls) == len(loss_params_reg) == len(
            loss_params_ide)
        assert len(loss_confs) == len(loss_params_cls)

        # average number of valid gt boxes per image in the batch
        batchsize = targets['gt_bbox'].shape[0]
        nTargets = paddle.nonzero(paddle.sum(targets['gt_bbox'], axis=2)).shape[
            0] / batchsize
        nTargets = paddle.to_tensor(nTargets, dtype='float32')
        nTargets.stop_gradient = True

        # combine the per-level losses with the loss_params_* callables
        # (loss-weighting terms provided by the model)
        jde_losses = []
        for i, (loss_conf, loss_box, loss_ide, l_conf_p, l_box_p,
                l_ide_p) in enumerate(
                    zip(loss_confs, loss_boxes, loss_ides, loss_params_cls,
                        loss_params_reg, loss_params_ide)):
            jde_loss = l_conf_p(loss_conf) + l_box_p(loss_box) + l_ide_p(
                loss_ide)
            jde_losses.append(jde_loss)

        loss_all = {
            "loss_conf": sum(loss_confs),
            "loss_box": sum(loss_boxes),
            "loss_ide": sum(loss_ides),
            "loss": sum(jde_losses),
            "nTargets": nTargets,
        }
        return loss_all
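

if __name__ == "__main__":
    # Minimal usage sketch for JDELoss only, not part of the library API.
    # The identity function below is a hypothetical stand-in for the per-level
    # loss-weighting callables that the model normally supplies.
    def identity(x):
        return x

    fake_targets = {'gt_bbox': paddle.rand([2, 6, 4])}  # [batch, num_gt, 4]
    loss_confs = [paddle.to_tensor(v) for v in (0.5, 0.4, 0.3)]
    loss_boxes = [paddle.to_tensor(0.2) for _ in range(3)]
    loss_ides = [paddle.to_tensor(0.1) for _ in range(3)]

    out = JDELoss()(loss_confs, loss_boxes, loss_ides,
                    [identity] * 3, [identity] * 3, [identity] * 3,
                    fake_targets)
    print({k: float(v) for k, v in out.items()})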