import torch
import torch.nn as nn


class Dropout(nn.Module):
    """Inverted dropout layer.

    In training mode each element of the input is zeroed independently with
    probability ``p``. Surviving elements are scaled by ``1 / (1 - p)`` so the
    expected value of the output equals the input. In eval mode the input is
    returned unchanged.

    Args:
        p: Probability of zeroing an element. Must lie in ``[0, 1]``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, p=0.1):
        super().__init__()
        # Fail fast here rather than with an opaque bernoulli error in forward().
        if not 0.0 <= p <= 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}"
            )
        self.p = p

    def forward(self, x):
        """Apply dropout to ``x`` in training mode; identity otherwise.

        Args:
            x: Input tensor of any shape.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if not self.training or self.p == 0.0:
            return x
        if self.p == 1.0:
            # Everything is dropped. Avoids the 0/0 -> NaN of the scaling step.
            return torch.zeros_like(x)
        keep = 1.0 - self.p
        # Build the mask from x itself so it lives on the same device and has
        # the same dtype (torch.ones(x.shape) would always be a CPU float32 tensor).
        mask = torch.empty_like(x).bernoulli_(keep)
        return x * mask / keep
    