# NOTE(review): removed a stray run of line numbers (extraction artifact).
import torch
import torch.nn as nn
from torchvision.models import resnet50
class DETR(nn.Module):
    """Minimal DETR-style object detector.

    A ResNet-50 backbone produces a feature map that is projected to
    ``hidden_dim`` channels, flattened into a token sequence, summed with a
    learned 2-D positional encoding, and passed through a transformer
    encoder/decoder. ``num_queries`` learned object queries attend to the
    encoded image and are decoded into per-query class logits and boxes.

    Args:
        num_classes: Number of foreground classes (one extra "no object"
            logit is added internally).
        num_queries: Number of learned object queries (max detections).
        hidden_dim: Transformer model width.
    """

    def __init__(self, num_classes=80, num_queries=100, hidden_dim=256):
        super().__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries
        # Backbone; its FC head is unused, so replace it with Identity.
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # (use `weights=...`) and triggers a download — kept for compatibility.
        self.backbone = resnet50(pretrained=True)
        self.backbone.fc = nn.Identity()
        # Project 2048-channel ResNet features down to the transformer width.
        self.conv1x1 = nn.Conv2d(2048, hidden_dim, 1)
        # Learned 2-D positional encoding on a 20x20 reference grid; it is
        # resized in forward() to match the actual feature-map size.
        self.pos_encoder = nn.Parameter(torch.randn(1, hidden_dim, 20, 20))
        encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
        decoder_layer = nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=8)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=6)
        # One learned embedding per object query.
        self.query_embed = nn.Parameter(torch.randn(num_queries, hidden_dim))
        # +1 logit for the "no object" class.
        self.class_head = nn.Linear(hidden_dim, num_classes + 1)
        # Boxes squashed to [0, 1] — presumably normalized (cx, cy, w, h);
        # confirm against the matching loss used by the caller.
        self.bbox_head = nn.Sequential(
            nn.Linear(hidden_dim, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run detection on a batch of images.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``.

        Returns:
            Tuple ``(class_logits, bbox_coords)`` with shapes
            ``(B, num_queries, num_classes + 1)`` and ``(B, num_queries, 4)``.
        """
        B = x.shape[0]
        # Manual ResNet-50 trunk (skip avgpool/fc) to keep the spatial map.
        feat = self.backbone.conv1(x)
        feat = self.backbone.bn1(feat)
        feat = self.backbone.relu(feat)
        feat = self.backbone.maxpool(feat)
        feat = self.backbone.layer1(feat)
        feat = self.backbone.layer2(feat)
        feat = self.backbone.layer3(feat)
        feat = self.backbone.layer4(feat)
        feat = self.conv1x1(feat)  # (B, hidden_dim, H/32, W/32)
        # BUG FIX: the original added the fixed 20x20 positional grid directly,
        # which raises a shape error whenever the feature map is not 20x20
        # (e.g. 800x800 input -> 25x25 map). Resize the learned grid to the
        # actual feature-map size first.
        pos = self.pos_encoder
        if pos.shape[-2:] != feat.shape[-2:]:
            pos = nn.functional.interpolate(
                pos, size=feat.shape[-2:], mode="bilinear", align_corners=False
            )
        feat = feat.flatten(2).permute(2, 0, 1)  # (HW, B, hidden_dim)
        pos = pos.flatten(2).permute(2, 0, 1)    # (HW, 1, hidden_dim) — broadcasts over B
        memory = self.encoder(feat + pos)
        # Queries as decoder targets: zero inputs + learned query embeddings.
        query_embed = self.query_embed.unsqueeze(1).repeat(1, B, 1)  # (Q, B, hidden_dim)
        tgt = torch.zeros_like(query_embed)
        hs = self.decoder(tgt, memory)
        hs = hs.transpose(0, 1)  # (B, Q, hidden_dim)
        class_logits = self.class_head(hs)
        bbox_coords = self.bbox_head(hs)
        return class_logits, bbox_coords
if __name__ == "__main__":
    # Smoke test: two 800x800 RGB images through the model. Guarded so that
    # importing this module does not trigger a forward pass (or the backbone
    # weight download inside DETR.__init__).
    model = DETR(num_classes=80, num_queries=100)
    model.eval()  # inference-mode batchnorm/dropout behavior
    x = torch.randn(2, 3, 800, 800)
    with torch.no_grad():  # no autograd bookkeeping needed for a shape check
        class_logits, bbox_coords = model(x)
    print(class_logits.shape)  # expected: torch.Size([2, 100, 81])
    print(bbox_coords.shape)   # expected: torch.Size([2, 100, 4])
|