import torch
import torch.nn.functional as F


class Conv2d():
    """Naive (loop-based) multi-channel 2-D cross-correlation without bias.

    Args:
        kernel: weight tensor of shape ``(out_channels, in_channels, h, w)``.
        padding: number of zeros added on every side of both spatial axes.
        stride: step between successive kernel positions on both spatial axes.
    """

    def __init__(self, kernel, padding=0, stride=1):
        super().__init__()
        self.kernel = kernel
        self.padding = padding
        self.stride = stride

    def forward(self, x):
        """Apply the convolution to a single (un-batched) input.

        Args:
            x: tensor of shape ``(in_channels, H, W)``.

        Returns:
            Tensor of shape ``(out_channels, H_out, W_out)`` where
            ``H_out = (H + 2*padding - h) // stride + 1`` and likewise for
            ``W_out``. Its dtype is the promotion of ``x``'s dtype, the
            kernel's dtype and the default floating dtype; it is allocated on
            ``x``'s device.
        """
        out_channels, in_channels, h, w = self.kernel.shape
        H, W = x.shape[1], x.shape[2]
        p, s = self.padding, self.stride

        # Allocate work buffers in a dtype wide enough for both operands (and
        # at least floating point) on the input's device, so float64 inputs are
        # not truncated and GPU inputs don't hit a device mismatch.
        dtype = torch.promote_types(
            torch.result_type(x, self.kernel), torch.get_default_dtype()
        )

        H_pad = H + 2 * p
        W_pad = W + 2 * p
        x_pad = torch.zeros(in_channels, H_pad, W_pad, dtype=dtype, device=x.device)
        x_pad[:, p:p + H, p:p + W] = x

        H_out = (H_pad - h) // s + 1
        W_out = (W_pad - w) // s + 1

        z = torch.zeros(out_channels, H_out, W_out, dtype=dtype, device=x.device)

        for i in range(H_out):
            for j in range(W_out):
                window = x_pad[:, i * s:i * s + h, j * s:j * s + w]
                # (in, h, w) broadcasts against (out, in, h, w); summing the
                # last three axes yields one value per output channel.
                z[:, i, j] = (window * self.kernel).sum(dim=(1, 2, 3))

        return z

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
    

class Conv1d():
    """Single-channel 1-D cross-correlation with a hand-written backward pass.

    Args:
        kernel: 1-D tensor of length ``h`` holding the filter weights.
    """

    def __init__(self, kernel):
        super().__init__()
        self.kernel = kernel

    def forward(self, x):
        """Cross-correlate the 1-D signal ``x`` (length ``L``) with the kernel.

        Returns:
            Tensor of shape ``(1, 1, L - h + 1)`` (no padding, stride 1).
        """
        # Add batch and channel axes so F.conv1d accepts the 1-D tensors.
        return F.conv1d(x[None, None, :], self.kernel[None, None, :])

    def calculate_gradients(self, dLdf, x):
        """Backpropagate an upstream gradient through :meth:`forward`.

        Args:
            dLdf: 1-D gradient of the loss w.r.t. the forward output,
                length ``L - h + 1``.
            x: the 1-D input (length ``L``) that was passed to ``forward``.

        Returns:
            ``(dLdx, dLdk)`` with shapes ``(1, 1, L)`` and ``(1, 1, h)``:
            ``dLdx[n] = sum_k dLdf[n - k] * kernel[k]`` (full convolution) and
            ``dLdk[k] = sum_n x[n + k] * dLdf[n]`` (valid cross-correlation).
        """
        h = self.kernel.shape[0]
        # Pad by h-1 on both sides so the "valid" correlation below covers
        # every input position, i.e. becomes a full convolution.
        padded = F.pad(dLdf, (h - 1, h - 1))

        # Cross-correlating with the flipped kernel is a true convolution.
        flipped = torch.flip(self.kernel, dims=(0,))

        dLdx = F.conv1d(padded[None, None, :], flipped[None, None, :])
        dLdk = F.conv1d(x[None, None, :], dLdf[None, None, :])

        return dLdx, dLdk

    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)
        
        
        
        
        
        
        
        
        
        
        
        
        
        