import torch
import torchvision
from torch import nn


def create_effnetb2(seed: int = 42, num_classes: int = 3, dropout: float = 0.3):
    """Build a pretrained EfficientNet-B2 with a frozen backbone and a new head.

    Loads ``EfficientNet_B2_Weights.DEFAULT`` and freezes every pretrained
    parameter (``requires_grad = False``). It then replaces the classifier
    with a fresh ``Dropout -> Linear`` head that outputs ``num_classes``
    logits. Only the new head's parameters are trainable.

    Args:
        seed: Passed to ``torch.manual_seed`` immediately before the new head
            is created, so the head's initial weights are reproducible. Note:
            this seeds the *global* torch RNG as a side effect.
        num_classes: Number of outputs of the new classifier head.
        dropout: Dropout probability used in the new head (default 0.3).

    Returns:
        A ``(model, transform)`` tuple: the modified model and the
        preprocessing transform that accompanies the pretrained weights.
    """
    # Pretrained weights, their matching preprocessing transform, and the
    # model initialised from those weights.
    weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
    transform = weights.transforms()
    model = torchvision.models.efficientnet_b2(weights=weights)

    # Freeze all pretrained parameters so only the new head is trained.
    for param in model.parameters():
        param.requires_grad = False

    # Take the head's input width from the existing final Linear layer instead
    # of hard-coding 1408, so this stays correct if the backbone changes.
    in_features = model.classifier[-1].in_features

    # Seed right before building the head so its random init is reproducible.
    torch.manual_seed(seed)
    model.classifier = nn.Sequential(
        nn.Dropout(p=dropout, inplace=True),
        nn.Linear(in_features=in_features, out_features=num_classes),
    )
    return model, transform