| import torch.nn as nn |
|
|
| |
|
|
| from transformers import PreTrainedModel |
|
|
| from .configuration_spice_cnn import SpiceCNNConfig |
|
|
|
|
class SpiceCNNModelForImageClassification(PreTrainedModel):
    """Small three-stage CNN image classifier wrapped as a Hugging Face model.

    The whole network is one ``nn.Sequential`` stored in ``self.model``:
    three ``Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d`` stages with 16, 32
    and 64 output channels, then
    ``Flatten -> Linear(128) -> ReLU -> Dropout(0.5) -> Linear(num_classes)``.

    Config attributes read: ``in_channels``, ``kernel_size``,
    ``pooling_size``, ``num_classes`` and, optionally, ``image_size``.
    """

    config_class = SpiceCNNConfig

    # Output channels of the three convolutional stages, in order.
    _CONV_CHANNELS = (16, 32, 64)
    # Zero-padding used by every Conv2d (stride is PyTorch's default of 1).
    _CONV_PADDING = 1
    # Side length of the final (square) feature map assumed when the config
    # has no ``image_size``; gives a first Linear input width of 64 * 3 * 3.
    _DEFAULT_FEATURE_SIDE = 3

    def __init__(self, config: SpiceCNNConfig):
        super().__init__(config)
        layers = []
        in_channels = config.in_channels
        for out_channels in self._CONV_CHANNELS:
            layers += [
                nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=config.kernel_size,
                    padding=self._CONV_PADDING,
                ),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                nn.MaxPool2d(kernel_size=config.pooling_size),
            ]
            in_channels = out_channels
        layers += [
            nn.Flatten(),
            nn.Linear(self._flattened_features(config), 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, config.num_classes),
        ]
        # Keep this flat layer order: state-dict keys (``model.0`` ...
        # ``model.16``) and parameter-initialisation order depend on it.
        self.model = nn.Sequential(*layers)

    @staticmethod
    def _pair(value):
        """Return ``value`` (an int or a 2-item sequence) as ``(h, w)`` ints."""
        if isinstance(value, int):
            return value, value
        height, width = value
        return int(height), int(width)

    @classmethod
    def _flattened_features(cls, config) -> int:
        """Return the input width of the first ``nn.Linear`` layer.

        If ``config`` has a non-None ``image_size`` (an int or a
        ``(height, width)`` pair), the spatial size is propagated through the
        conv/pool stages with the Conv2d and MaxPool2d output-size formulas,
        so any image size, kernel size or pooling size is handled.
        Otherwise a ``_DEFAULT_FEATURE_SIDE`` x ``_DEFAULT_FEATURE_SIDE``
        final feature map is assumed.

        Raises:
            ValueError: if ``image_size`` is too small to survive all stages.
        """
        channels = cls._CONV_CHANNELS[-1]
        image_size = getattr(config, "image_size", None)
        if image_size is None:
            return channels * cls._DEFAULT_FEATURE_SIDE ** 2

        height, width = cls._pair(image_size)
        kernel_h, kernel_w = cls._pair(config.kernel_size)
        pool_h, pool_w = cls._pair(config.pooling_size)
        pad = cls._CONV_PADDING
        for _ in cls._CONV_CHANNELS:
            # Conv2d with stride 1, dilation 1: out = in + 2*pad - kernel + 1
            height = height + 2 * pad - kernel_h + 1
            width = width + 2 * pad - kernel_w + 1
            # MaxPool2d, stride defaults to kernel, no padding:
            # out = floor((in - kernel) / kernel) + 1
            height = (height - pool_h) // pool_h + 1
            width = (width - pool_w) // pool_w + 1
            if height < 1 or width < 1:
                raise ValueError(
                    f"image_size={image_size!r} is too small for "
                    f"{len(cls._CONV_CHANNELS)} conv/pool stages with "
                    f"kernel_size={config.kernel_size!r} and "
                    f"pooling_size={config.pooling_size!r}"
                )
        return channels * height * width

    def forward(self, tensor, labels=None):
        """Run the classifier and optionally compute the training loss.

        Args:
            tensor: input batch fed straight to ``self.model``; presumably
                ``(batch, in_channels, height, width)`` — TODO confirm.
            labels: optional targets passed to ``nn.CrossEntropyLoss``
                together with the logits.

        Returns:
            ``{"logits": logits}``, or ``{"loss": loss, "logits": logits}``
            when ``labels`` is given.
        """
        logits = self.model(tensor)
        if labels is None:
            return {"logits": logits}
        loss = nn.CrossEntropyLoss()(logits, labels)
        return {"loss": loss, "logits": logits}
|
|
|
|
| |
| |
| |
|
|