|
|
| import os |
| import torch |
| import torch.nn as nn |
| from transformers import AutoModelForImageClassification, PreTrainedModel, hf_hub_download |
| from .configuration_resnet import CustomResNetConfig |
|
|
class CustomResNetModel(PreTrainedModel):
    """Image classifier wrapping a pretrained backbone with a resized head.

    The backbone named by ``config.model_name`` is loaded through
    ``AutoModelForImageClassification.from_pretrained`` and its ``classifier``
    attribute is replaced by ``Flatten -> Linear(in_features, config.num_labels)``.
    Weights are persisted as a plain ``pytorch_model.bin`` state dict next to
    ``config.json``.
    """

    config_class = CustomResNetConfig

    def __init__(self, config):
        """Build the backbone and swap in a new classification head.

        Args:
            config: A ``CustomResNetConfig``. ``model_name`` selects the backbone
                checkpoint and ``num_labels`` sizes the new output layer.
        """
        super().__init__(config)

        # Loads the pretrained backbone weights on every construction. Inside
        # ``from_pretrained`` below they are immediately overwritten by the
        # saved state dict.
        self.resnet = AutoModelForImageClassification.from_pretrained(config.model_name)

        # NOTE(review): assumes the backbone's ``classifier`` is indexable and
        # that element 1 is an nn.Linear (e.g. Sequential(Flatten, Linear)).
        # TODO confirm for every model_name this is used with.
        in_features = self.resnet.classifier[1].in_features
        self.resnet.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, config.num_labels),
        )

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output unchanged.

        ``x`` is passed positionally as the wrapped model's first argument.
        """
        return self.resnet(x)

    def save_pretrained(self, save_directory, **kwargs):
        """Write ``pytorch_model.bin`` (state dict) and ``config.json``.

        Args:
            save_directory: Target directory. It is created if it does not
                already exist.
            **kwargs: Accepted for signature compatibility with
                ``PreTrainedModel.save_pretrained`` but ignored here.
        """
        os.makedirs(save_directory, exist_ok=True)
        torch.save(self.state_dict(), os.path.join(save_directory, "pytorch_model.bin"))
        self.config.save_pretrained(save_directory)

    @classmethod
    def from_pretrained(cls, repo_id, **kwargs):
        """Load a model written by ``save_pretrained`` from the Hub or disk.

        Args:
            repo_id: Either a Hub repository id, or a local directory that
                contains ``pytorch_model.bin`` and ``config.json``.
            **kwargs: ``revision``, ``cache_dir``, ``force_download`` and
                ``local_files_only`` are forwarded to ``hf_hub_download`` for
                Hub downloads. Other keys are ignored.

        Returns:
            A ``cls`` instance with the saved weights loaded, put in eval mode
            to match ``PreTrainedModel.from_pretrained``.
        """
        if os.path.isdir(repo_id):
            # Local checkpoint, e.g. the output of save_pretrained above.
            model_path = os.path.join(repo_id, "pytorch_model.bin")
            config_path = os.path.join(repo_id, "config.json")
        else:
            # NOTE(review): hf_hub_download is defined by huggingface_hub.
            # Confirm the module-level `from transformers import ...
            # hf_hub_download` resolves in the pinned transformers version.
            hub_kwargs = {
                key: kwargs[key]
                for key in ("revision", "cache_dir", "force_download", "local_files_only")
                if key in kwargs
            }
            model_path = hf_hub_download(
                repo_id=repo_id, filename="pytorch_model.bin", **hub_kwargs
            )
            config_path = hf_hub_download(
                repo_id=repo_id, filename="config.json", **hub_kwargs
            )

        config = cls.config_class.from_pretrained(config_path)
        model = cls(config)

        # SECURITY(review): torch.load unpickles the file. A checkpoint from an
        # untrusted repo can execute arbitrary code. Consider passing
        # weights_only=True (available in torch >= 1.13).
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)

        # Match the transformers convention: loaded models start in eval mode,
        # so BatchNorm/Dropout layers behave correctly for inference.
        model.eval()
        return model
|
|