1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
| import torch.nn as nn import torch.nn.functional as F import torchvision from torchsummary import summary
class VGG(nn.Module):
    """VGG builder with batch-normalised conv stages and a 3-layer classifier.

    The network has five conv stages producing 64, 128, 256, 512 and 512
    channels. Stage ``i`` is ``arch[i]`` repetitions of
    ``Conv2d(3x3, stride 1, padding 1, no bias) -> BatchNorm2d -> ReLU`` and is
    followed by a 2x2 max-pool in :meth:`forward`. The resulting feature map
    is resized to 7x7 (a no-op for 224x224 inputs), flattened, and passed
    through ``Linear -> BatchNorm1d -> ReLU`` twice and a final ``Linear``.

    Because of the five 2x2 max-pools, inputs must be at least 32x32.

    Args:
        arch: Number of conv layers in each of the five stages, e.g.
            ``[2, 2, 3, 3, 3]`` for VGG-16.
        num_classes: Size of the output layer.
        hidden_features: Width of the two hidden fully-connected layers.
        apply_softmax: If True (default, the original behaviour) ``forward``
            returns softmax probabilities over classes. Set to False to get
            raw logits, which is what ``nn.CrossEntropyLoss`` expects.

    Raises:
        ValueError: If ``arch`` does not have exactly five entries.
    """

    # Spatial grid the classifier is sized for after the last pooling step.
    _POOLED_SIZE = (7, 7)

    def __init__(
        self,
        arch: list,
        num_classes=1000,
        *,
        hidden_features: int = 4096,
        apply_softmax: bool = True,
    ):
        super().__init__()
        if len(arch) != 5:
            raise ValueError(
                f"arch must have 5 entries (one per conv stage), got {len(arch)}"
            )
        self.apply_softmax = apply_softmax
        # Running input-channel count; __make_layer advances it stage by stage.
        self.in_channels = 3
        self.conv3_64 = self.__make_layer(64, arch[0])
        self.conv3_128 = self.__make_layer(128, arch[1])
        self.conv3_256 = self.__make_layer(256, arch[2])
        self.conv3_512a = self.__make_layer(512, arch[3])
        self.conv3_512b = self.__make_layer(512, arch[4])
        pooled_h, pooled_w = self._POOLED_SIZE
        self.fc1 = nn.Linear(in_features=pooled_h * pooled_w * 512, out_features=hidden_features)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_features)
        self.bn2 = nn.BatchNorm1d(num_features=hidden_features)
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=hidden_features)
        self.fc3 = nn.Linear(in_features=hidden_features, out_features=num_classes)

    def __make_layer(self, out_channels, num):
        """Return a Sequential of ``num`` Conv-BN-ReLU units ending in ``out_channels``.

        Side effect: sets ``self.in_channels`` to ``out_channels`` so the next
        stage chains onto this one.
        """
        layers = []
        for _ in range(num):
            layers.append(nn.Conv2d(in_channels=self.in_channels, out_channels=out_channels,
                                    kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, t):
        """Run the network on an image batch ``t`` of shape (N, 3, H, W).

        Returns an (N, num_classes) tensor: softmax probabilities when
        ``apply_softmax`` is True, otherwise raw logits.
        """
        for stage in (self.conv3_64, self.conv3_128, self.conv3_256,
                      self.conv3_512a, self.conv3_512b):
            t = F.max_pool2d(stage(t), 2)
        # Resize to the grid fc1 was built for; exact identity for 224x224 input.
        t = F.adaptive_avg_pool2d(t, self._POOLED_SIZE)
        t = t.flatten(1)
        t = F.relu(self.bn1(self.fc1(t)))
        t = F.relu(self.bn2(self.fc2(t)))
        logits = self.fc3(t)
        # getattr default keeps whole-model pickles from before this flag working.
        if getattr(self, "apply_softmax", True):
            return F.softmax(logits, dim=1)
        return logits
def VGG_11():
    """Build VGG-11 (configuration A): 8 conv layers plus 3 fully-connected layers."""
    layers_per_stage = [1, 1, 2, 2, 2]
    return VGG(arch=layers_per_stage, num_classes=1000)
def VGG_13():
    """Build VGG-13 (configuration B): 10 conv layers plus 3 fully-connected layers.

    Every one of the five stages has two conv layers. The previous
    ``[1, 1, 2, 2, 2]`` duplicated VGG-11 and produced only 11 weight layers.
    """
    return VGG([2, 2, 2, 2, 2], num_classes=1000)
def VGG_16():
    """Build VGG-16 (configuration D): 13 conv layers plus 3 fully-connected layers."""
    layers_per_stage = [2, 2, 3, 3, 3]
    return VGG(arch=layers_per_stage, num_classes=1000)
def VGG_19():
    """Build VGG-19 (configuration E): 16 conv layers plus 3 fully-connected layers."""
    layers_per_stage = [2, 2, 4, 4, 4]
    return VGG(arch=layers_per_stage, num_classes=1000)
def test():
    """Smoke test: build VGG-19 and print a torchsummary for a 3x224x224 input.

    Uses the GPU when one is available and falls back to the CPU otherwise,
    instead of crashing in ``.cuda()`` on CPU-only machines.
    """
    import torch  # the module itself only imports torch.nn / torch.nn.functional

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = VGG_19().to(device)
    # torchsummary defaults to device="cuda"; pass the real device explicitly.
    summary(net, (3, 224, 224), device=device)
if __name__ == '__main__':
    # Run the smoke test only when executed as a script, not on import.
    test()
|