# ------------------------------------------------------------------------
# Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
|
|
This file provides the definition of the convolutional heads used to predict masks, as well as the losses
|
|
"""
|
|
import io
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

import util.box_ops as box_ops
from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list

try:
    from panopticapi.utils import id2rgb, rgb2id
except ImportError:
    pass


class DETRsegm(nn.Module):
    def __init__(self, detr, freeze_detr=False):
        super().__init__()
        self.detr = detr

        if freeze_detr:
            for p in self.parameters():
                p.requires_grad_(False)

        hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead
        self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0)
        self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)

    def forward(self, samples: NestedTensor):
        if not isinstance(samples, NestedTensor):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.detr.backbone(samples)

        bs = features[-1].tensors.shape[0]

        src, mask = features[-1].decompose()
        src_proj = self.detr.input_proj(src)
        hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1])

        outputs_class = self.detr.class_embed(hs)
        outputs_coord = self.detr.bbox_embed(hs).sigmoid()
        out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
        if self.detr.aux_loss:
            out["aux_outputs"] = [
                {"pred_logits": a, "pred_boxes": b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
            ]

        # FIXME h_boxes takes the last one computed, keep this in mind
        bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask)

        seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors])
        outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1])

        out["pred_masks"] = outputs_seg_masks
        return out

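# Usage sketch (illustrative, not taken from the original training script): DETRsegm wraps an
# already-built DETR detector and adds the mask branch on top of it, e.g.
#
#   model = DETRsegm(detr, freeze_detr=True)   # train only the segmentation head
#   out = model(samples)                       # adds out["pred_masks"] of shape
#                                              # [batch, num_queries, h, w]
#
# where (h, w) is the spatial resolution of the mask head output (a fraction of the input size).
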
class MaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm.
    Upsampling is done using an FPN approach
    """

    def __init__(self, dim, fpn_dims, context_dim):
        super().__init__()

        inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
        self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = torch.nn.GroupNorm(8, dim)
        self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
        self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
        self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
        self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
        self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1)

        self.dim = dim

        self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, bbox_mask, fpns):
        def expand(tensor, length):
            return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)

        x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)

        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)

        x = self.out_lay(x)
        return x

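# Channel bookkeeping (worked example, assuming the typical DETR configuration with hidden_dim=256
# and nheads=8): DETRsegm builds this head as MaskHeadSmallConv(256 + 8, [1024, 512, 256], 256),
# so dim = 264 and inter_dims = [264, 128, 64, 32, 16, 4]. The main path then goes
# 264 -> 264 -> 128 -> 64 -> 32 -> 16 -> 1 channels, while adapter1/2/3 project the
# 1024/512/256-channel backbone features to 128/64/32 channels before they are added
# FPN-style at each upsampling step.
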
class MHAttentionMap(nn.Module):
    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        nn.init.zeros_(self.k_linear.bias)
        nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask=None):
        q = self.q_linear(q)
        k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)

        if mask is not None:
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
        weights = self.dropout(weights)
        return weights

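# Shape summary (stated here for reference, following how DETRsegm calls this module): q is the
# decoder output hs[-1] of shape [batch, num_queries, hidden_dim], k is the encoder memory of shape
# [batch, hidden_dim, H, W], and the returned weights have shape [batch, num_queries, num_heads, H, W].
# Note that the softmax is taken jointly over the flattened (num_heads, H, W) dimensions per query.
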
def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes

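# Worked example (sanity check): for a single box whose sigmoided predictions match a binary target
# with 100 positive pixels (in the limit of saturated logits), numerator = 2 * 100 and
# denominator = 100 + 100, so loss = 1 - (200 + 1) / (200 + 1) = 0; a completely wrong prediction
# pushes the loss towards 1. The +1 terms smooth the ratio for empty masks.
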
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0, 1) to balance
               positive vs negative examples. A negative value disables the weighting
               (the default here is 0.25).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes

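# Worked example: a positive pixel predicted with p = 0.9 gives p_t = 0.9, so with gamma = 2 the
# modulating factor (1 - p_t) ** 2 = 0.01 and this easy example contributes ~100x less than plain
# BCE would; with alpha = 0.25 positive terms are further scaled by 0.25 and negative terms by 0.75.
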
class PostProcessSegm(nn.Module):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    @torch.no_grad()
    def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
        assert len(orig_target_sizes) == len(max_target_sizes)
        max_h, max_w = max_target_sizes.max(0)[0].tolist()
        outputs_masks = outputs["pred_masks"].squeeze(2)
        outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
        outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()

        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
            img_h, img_w = t[0], t[1]
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            results[i]["masks"] = F.interpolate(
                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
            ).byte()

        return results

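# Resizing note (summary of the behaviour above): the predicted masks are first upsampled to the
# padded batch resolution (the maximum of max_target_sizes), thresholded, cropped to each image's
# unpadded size, and only then resized to the original image resolution (orig_target_sizes) so they
# can be compared against the ground-truth annotations.
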
class PostProcessPanoptic(nn.Module):
    """This class converts the output of the model to the final panoptic result, in the format expected by the
    coco panoptic API"""

    def __init__(self, is_thing_map, threshold=0.85):
        """
        Parameters:
            is_thing_map: This is a dict whose keys are the class ids, and the values a boolean indicating whether
                          the class is a thing (True) or a stuff (False) class
            threshold: confidence threshold: segments with confidence lower than this will be deleted
        """
        super().__init__()
        self.threshold = threshold
        self.is_thing_map = is_thing_map

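    # Example of how is_thing_map might be built (illustrative; the exact construction lives in the
    # dataset code, not in this file): COCO panoptic category metadata carries an "isthing" flag, so
    #   is_thing_map = {cat["id"]: cat["isthing"] == 1 for cat in panoptic_categories}
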
    def forward(self, outputs, processed_sizes, target_sizes=None):
        """This function computes the panoptic prediction from the model's predictions.
        Parameters:
            outputs: This is a dict coming directly from the model. See the model doc for the content.
            processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to
                             the model, i.e. the size after data augmentation but before batching.
            target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size
                          of each prediction. If left to None, it will default to the processed_sizes
        """
        if target_sizes is None:
            target_sizes = processed_sizes
        assert len(processed_sizes) == len(target_sizes)
        out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"]
        assert len(out_logits) == len(raw_masks) == len(target_sizes)
        preds = []

        def to_tuple(tup):
            if isinstance(tup, tuple):
                return tup
            return tuple(tup.cpu().tolist())

        for cur_logits, cur_masks, cur_boxes, size, target_size in zip(
            out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes
        ):
            # we filter out empty queries and detections below the threshold
            scores, labels = cur_logits.softmax(-1).max(-1)
            keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold)
            cur_scores, cur_classes = cur_logits.softmax(-1).max(-1)
            cur_scores = cur_scores[keep]
            cur_classes = cur_classes[keep]
            cur_masks = cur_masks[keep]
            cur_masks = interpolate(cur_masks[None], to_tuple(size), mode="bilinear").squeeze(0)
            cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep])

            h, w = cur_masks.shape[-2:]
            assert len(cur_boxes) == len(cur_classes)

            # It may be that we have several predicted masks for the same stuff class.
            # In the following, we track the list of masks ids for each stuff class (they are merged later on)
            cur_masks = cur_masks.flatten(1)
            stuff_equiv_classes = defaultdict(lambda: [])
            for k, label in enumerate(cur_classes):
                if not self.is_thing_map[label.item()]:
                    stuff_equiv_classes[label.item()].append(k)

            def get_ids_area(masks, scores, dedup=False):
                # This helper function creates the final panoptic segmentation image
                # It also returns the areas of the masks that appear on the image

                m_id = masks.transpose(0, 1).softmax(-1)

                if m_id.shape[-1] == 0:
                    # We didn't detect any mask :(
                    m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device)
                else:
                    m_id = m_id.argmax(-1).view(h, w)

                if dedup:
                    # Merge the masks corresponding to the same stuff class
                    for equiv in stuff_equiv_classes.values():
                        if len(equiv) > 1:
                            for eq_id in equiv:
                                m_id.masked_fill_(m_id.eq(eq_id), equiv[0])

                final_h, final_w = to_tuple(target_size)

                seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy()))
                seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST)

                np_seg_img = (
                    torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy()
                )
                m_id = torch.from_numpy(rgb2id(np_seg_img))

                area = []
                for i in range(len(scores)):
                    area.append(m_id.eq(i).sum().item())
                return area, seg_img

            area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True)
            if cur_classes.numel() > 0:
                # We now filter out empty masks, as long as some detections remain
                while True:
                    filtered_small = torch.as_tensor(
                        [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device
                    )
                    if filtered_small.any().item():
                        cur_scores = cur_scores[~filtered_small]
                        cur_classes = cur_classes[~filtered_small]
                        cur_masks = cur_masks[~filtered_small]
                        area, seg_img = get_ids_area(cur_masks, cur_scores)
                    else:
                        break

            else:
                cur_classes = torch.ones(1, dtype=torch.long, device=cur_classes.device)

            segments_info = []
            for i, a in enumerate(area):
                cat = cur_classes[i].item()
                segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a})
            del cur_classes

            with io.BytesIO() as out:
                seg_img.save(out, format="PNG")
                predictions = {"png_string": out.getvalue(), "segments_info": segments_info}
            preds.append(predictions)
        return preds
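
# Output format (for reference, derived from the code above): PostProcessPanoptic.forward returns one
# dict per image with
#   - "png_string":    the panoptic id map encoded as a PNG (segment ids packed into RGB via id2rgb),
#     which is the file format consumed by the COCO panoptic API
#   - "segments_info": one entry per segment with "id", "isthing", "category_id" and "area"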