1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28
import torch
import torch.nn as nn
import torchvision
class TwoStreamNet(nn.Module):
    """Two-stream classifier: a spatial (RGB) stream and a temporal (flow) stream.

    Two independent torchvision classifiers are built from the same backbone.
    The spatial stream receives a single RGB frame; the temporal stream receives
    a stack of ``flow_channels`` input planes (presumably stacked optical-flow
    fields). :meth:`forward` averages the two streams' outputs (late fusion).

    Args:
        num_classes: Output size of each stream's final ``nn.Linear`` head.
        backbone: Name of a torchvision model constructor with a ResNet-style
            layout, i.e. a ``conv1`` ``nn.Conv2d`` stem and an ``fc``
            ``nn.Linear`` head (e.g. ``"resnet18"``, ``"resnet50"``).
        flow_channels: Number of input channels of the temporal stream. The
            default 20 matches the previously hard-coded value (presumably
            10 frames x (dx, dy) flow — confirm against the data loader).
        pretrained: Load ImageNet weights into both backbones.
        cross_modal_init: When ``pretrained`` is true, initialise the temporal
            stem by averaging the pretrained RGB filters over their input
            channels and replicating the result across all ``flow_channels``,
            instead of leaving that layer randomly initialised.

    Raises:
        ValueError: If ``backbone`` is not a torchvision constructor producing
            a model with a ``conv1`` Conv2d stem and an ``fc`` Linear head.
    """

    def __init__(self, num_classes: int = 101, backbone: str = 'resnet18', *,
                 flow_channels: int = 20, pretrained: bool = True,
                 cross_modal_init: bool = True):
        super().__init__()
        self.spatial_net = self._build_backbone(backbone, pretrained)
        self.temporal_net = self._build_backbone(backbone, pretrained)

        # The temporal stream takes flow_channels planes instead of 3 RGB
        # channels, so its stem is rebuilt with the backbone's own geometry.
        self.temporal_net.conv1 = self._inflate_conv(
            self.temporal_net.conv1,
            flow_channels,
            copy_weights=pretrained and cross_modal_init,
        )

        # Replace the ImageNet heads. Read in_features from the backbone rather
        # than hard-coding 512 so wider backbones (e.g. resnet50) also work.
        self.spatial_net.fc = nn.Linear(self.spatial_net.fc.in_features, num_classes)
        self.temporal_net.fc = nn.Linear(self.temporal_net.fc.in_features, num_classes)

    @staticmethod
    def _build_backbone(name, pretrained):
        """Instantiate torchvision model ``name`` and check its ResNet-style layout."""
        builder = getattr(torchvision.models, name, None)
        if not callable(builder):
            raise ValueError(f"unknown torchvision backbone: {name!r}")
        # NOTE: ``pretrained=`` is deprecated since torchvision 0.13 in favour of
        # ``weights=``; it is kept so older torchvision releases keep working.
        net = builder(pretrained=pretrained)
        if not (isinstance(getattr(net, "conv1", None), nn.Conv2d)
                and isinstance(getattr(net, "fc", None), nn.Linear)):
            raise ValueError(
                f"backbone {name!r} must expose a 'conv1' Conv2d stem "
                f"and an 'fc' Linear head"
            )
        return net

    @staticmethod
    def _inflate_conv(conv, in_channels, *, copy_weights):
        """Return a Conv2d shaped like ``conv`` but accepting ``in_channels`` inputs.

        When ``copy_weights`` is true, every input channel of the new filters is
        set to the mean of ``conv``'s filters over its input channels, and the
        bias (if any) is copied. Otherwise the layer keeps PyTorch's default
        random initialisation.
        """
        new_conv = nn.Conv2d(
            in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            bias=conv.bias is not None,
        )
        if copy_weights:
            # Parameter initialisation must not be recorded by autograd.
            with torch.no_grad():
                mean_filter = conv.weight.mean(dim=1, keepdim=True)
                new_conv.weight.copy_(mean_filter.expand_as(new_conv.weight))
                if conv.bias is not None:
                    new_conv.bias.copy_(conv.bias)
        return new_conv

    def forward(self, rgb_frame: torch.Tensor, flow_stack: torch.Tensor):
        """Run both streams and fuse their outputs by averaging.

        Args:
            rgb_frame: Spatial-stream input (presumably ``(N, 3, H, W)``).
            flow_stack: Temporal-stream input (presumably
                ``(N, flow_channels, H, W)``).

        Returns:
            ``(fused_out, spatial_out, temporal_out)``, where ``fused_out`` is the
            element-wise mean of the two streams' outputs.
        """
        spatial_out = self.spatial_net(rgb_frame)
        temporal_out = self.temporal_net(flow_stack)
        fused_out = (spatial_out + temporal_out) / 2
        return fused_out, spatial_out, temporal_out
|