import torch
import torch.nn as nn
from torchvision.models import vgg16 as VGG


class ConvVGG(nn.Module):
    """VGG16 whose fully-connected classifier is recast as convolutions.

    The three ``nn.Linear`` layers of the torchvision VGG16 classifier
    (indices 0, 3 and 6) are replaced by a 7x7 convolution followed by two
    1x1 convolutions that carry the same weights and biases. Because the
    head is convolutional, a feature map larger than 7x7 yields a spatial
    map of 1000 class scores, which ``forward`` sums into one score vector
    per image.

    NOTE(review): the torchvision VGG also applies an adaptive average pool
    before its classifier and dropout between the Linear layers; neither is
    reproduced here. Outputs should therefore match the original model only
    in eval mode and when ``features`` produces exactly a 7x7 map — confirm
    this is intended.
    """

    def __init__(self, pretrained=True):
        """Build the network and transfer the VGG16 classifier weights.

        Args:
            pretrained: forwarded to the torchvision ``vgg16`` factory.
                ``True`` (the default, the previous behaviour) loads the
                pretrained weights; ``False`` builds the architecture with
                random initialisation, e.g. for offline use or tests.
        """
        super().__init__()
        vgg = VGG(pretrained=pretrained)
        self.features = vgg.features
        self.classifier = nn.Sequential(nn.Conv2d(512, 4096, 7),
                                        nn.ReLU(),
                                        nn.Conv2d(4096, 4096, 1),
                                        nn.ReLU(),
                                        nn.Conv2d(4096, 1000, 1))

        # Pair each conv in the new head with the Linear layer it replaces;
        # the Linear indices skip the non-parametric layers between them.
        pairs = ((self.classifier[0], vgg.classifier[0]),
                 (self.classifier[2], vgg.classifier[3]),
                 (self.classifier[4], vgg.classifier[6]))

        # Copy values in place instead of rebinding ``.data``: that bypasses
        # autograd bookkeeping and made each conv parameter share storage
        # with the Linear layer of the otherwise-discarded ``vgg`` model.
        with torch.no_grad():
            for conv, linear in pairs:
                # Linear weight (out, in_c*kh*kw) -> conv weight
                # (out, in_c, kh, kw); this relies on the first Linear having
                # consumed the 512x7x7 features flattened in C, H, W order.
                conv.weight.copy_(linear.weight.reshape_as(conv.weight))
                conv.bias.copy_(linear.bias)

    def forward(self, x):
        """Return per-image class scores of shape (N, 1000).

        Args:
            x: image batch accepted by ``self.features``; presumably
                (N, 3, H, W) — TODO confirm. The first head conv has a 7x7
                kernel, so the feature map must be at least 7x7 spatially.

        The head produces a 1000-channel score map that is summed (not
        averaged) over its two spatial dimensions. Score magnitudes
        therefore grow with the number of spatial positions; for a 1x1
        score map the sum is just that single score.
        """
        x = self.features(x)
        x = self.classifier(x)
        return x.sum((2, 3))
