import torch
import torch.nn as nn
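
# `MLP` is referenced below but never defined in this listing. A minimal
# projection-head sketch is assumed here (hypothetical; the actual DINO head
# is a 3-layer MLP with an L2-normalized bottleneck and a weight-normalized
# output layer, so treat this as a stand-in, not the paper's head).
class MLP(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, x):
        return self.net(x)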

class DINO(nn.Module):
    def __init__(self, student, teacher, embed_dim=256,
                 warmup_teacher_temp=0.04, teacher_temp=0.04,
                 warmup_teacher_temp_epochs=30):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.student_head = MLP(student.embed_dim, embed_dim)
        self.teacher_head = MLP(teacher.embed_dim, embed_dim)
        self.teacher_temp = teacher_temp
        self.warmup_teacher_temp = warmup_teacher_temp
        self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs
        # The teacher is never trained by backprop, only by an EMA of the
        # student, so freeze its head as well as its backbone (the original
        # listing froze only the backbone).
        for param in self.teacher.parameters():
            param.requires_grad = False
        for param in self.teacher_head.parameters():
            param.requires_grad = False
    @torch.no_grad()
    def update_teacher(self, m):
        # EMA update: teacher <- m * teacher + (1 - m) * student. The
        # projection heads must track the backbones as well, so both
        # parameter lists are chained together (the original listing
        # updated only the backbone, leaving the teacher head at its
        # random initialization).
        student_params = list(self.student.parameters()) + list(self.student_head.parameters())
        teacher_params = list(self.teacher.parameters()) + list(self.teacher_head.parameters())
        for param_q, param_k in zip(student_params, teacher_params):
            param_k.data.mul_(m).add_(param_q.detach().data, alpha=1 - m)
    def forward(self, global_view, local_view, epoch):
        # `epoch` was used but not passed in the original listing; the
        # caller supplies it so the teacher temperature schedule can apply.
        # Student sees the local crop; gradients flow through this path only.
        student_feats = self.student(local_view)
        student_logits = self.student_head(student_feats)
        with torch.no_grad():
            # Teacher sees the global crop; no gradients.
            teacher_feats = self.teacher(global_view)
            teacher_logits = self.teacher_head(teacher_feats)
            # Centering: subtracting the mean logit discourages collapse onto
            # a single dimension. (The paper keeps a running center updated
            # across batches; the per-batch mean here is a simplification
            # kept from the listing.)
            teacher_logits = teacher_logits - teacher_logits.mean(dim=0, keepdim=True)
            # Sharpening: a low teacher temperature discourages collapse
            # onto the uniform distribution.
            curr_temp = self.get_teacher_temp(epoch)
            teacher_probs = torch.softmax(teacher_logits / curr_temp, dim=1)
        # Cross-entropy between the sharpened teacher targets and the
        # student's log-probabilities; computed outside no_grad so the
        # student receives gradients.
        loss = -torch.sum(teacher_probs * torch.log_softmax(student_logits, dim=1), dim=1).mean()
        return loss
    def get_teacher_temp(self, epoch):
        # Step schedule: warmup value first, final value afterwards.
        if epoch < self.warmup_teacher_temp_epochs:
            return self.warmup_teacher_temp
        else:
            return self.teacher_temp
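
# A minimal usage sketch. Everything here is hypothetical scaffolding: any
# backbone that exposes `embed_dim` and maps a batch to (batch, embed_dim)
# features works; the toy backbone, input shapes, and momentum value below
# are illustrative choices, not part of the original listing.
class ToyBackbone(nn.Module):
    def __init__(self, embed_dim=384):
        super().__init__()
        self.embed_dim = embed_dim
        self.proj = nn.Linear(3 * 32 * 32, embed_dim)

    def forward(self, x):
        return self.proj(x.flatten(1))


if __name__ == "__main__":
    import copy

    student = ToyBackbone()
    teacher = copy.deepcopy(student)  # teacher starts as a copy of the student
    model = DINO(student, teacher)

    global_view = torch.randn(8, 3, 32, 32)  # e.g. a large crop
    local_view = torch.randn(8, 3, 32, 32)   # e.g. a small crop

    loss = model(global_view, local_view, epoch=0)
    loss.backward()                # gradients reach only the student
    model.update_teacher(m=0.996)  # then the teacher follows via EMA
    print(loss.item())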