| from transformers.modeling_outputs import SequenceClassifierOutput |
| from transformers.modeling_utils import PreTrainedModel |
| from transformers.configuration_utils import PretrainedConfig |
| import torch |
| from transformers import ZeroShotClassificationPipeline |
|
|
|
|
class CustomConfig(PretrainedConfig):
    """Configuration for the dummy ``test-zeroshot`` model.

    Declares no fields of its own: every keyword argument is handed
    unchanged to :class:`PretrainedConfig`.
    """

    model_type = "test-zeroshot"

    def __init__(self, **kwargs):
        # Pure pass-through; all option handling lives in the base class.
        super().__init__(**kwargs)
|
|
|
|
class CustomModel(PreTrainedModel):
    """Dummy sequence-classification model that emits constant logits.

    ``forward`` does no real computation: for every example in the batch it
    returns the scores ``[1.0, 2.0, 3.0]``. This is enough to drive pipeline
    plumbing such as ``ZeroShotClassificationPipeline`` in tests.
    """

    config_class = CustomConfig

    def __init__(self, config: CustomConfig):
        super().__init__(config)
        self.config = config
        # Unused by ``forward``. NOTE(review): presumably present so the model
        # owns at least one parameter (device lookups / save_pretrained) —
        # confirm before removing.
        self.embeddings = torch.nn.Embedding(num_embeddings=1, embedding_dim=1)

    @staticmethod
    def _infer_batch_size(inputs) -> int:
        """Return the leading dim of the first batched tensor input, else 1.

        ``input_ids`` is checked first; any other keyword tensor is used as
        a fallback. 0-d tensors and non-tensor values are ignored.
        """
        for value in (inputs.get("input_ids"), *inputs.values()):
            if isinstance(value, torch.Tensor) and value.dim() > 0:
                return int(value.shape[0])
        return 1

    def forward(self, **kwargs) -> SequenceClassifierOutput:
        """Return constant logits ``[1.0, 2.0, 3.0]`` for each example.

        Args:
            **kwargs: Model inputs (e.g. ``input_ids``, ``attention_mask``).
                Their values are ignored except for the batch size, taken from
                the leading dimension of the first tensor input.

        Returns:
            SequenceClassifierOutput whose ``logits`` is a float32 tensor of
            shape ``(batch_size, 3)`` on the same device as the model's
            parameters. Without any tensor input, the batch size is 1.
        """
        batch_size = self._infer_batch_size(kwargs)
        # One row per example so batched pipeline calls can be split back
        # into per-example outputs.
        logits = torch.tensor(
            [[1.0, 2.0, 3.0]],
            dtype=torch.float32,
            device=self.embeddings.weight.device,
        ).repeat(batch_size, 1)
        return SequenceClassifierOutput(logits=logits)
|
|
|
|
| from transformers.pipelines import PIPELINE_REGISTRY |
|
|
| from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification |
|
|
if __name__ == "__main__":
    # Script-only dependencies.
    from huggingface_hub import Repository
    from transformers import AutoTokenizer

    # Reuse the Japanese BERT tokenizer. Loading it directly avoids building a
    # whole zero-shot pipeline (and fetching model weights) only to read
    # ``classifier.tokenizer`` from it.
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

    # Record the custom classes in the saved config's ``auto_map`` so the
    # pushed repo can be loaded with ``trust_remote_code=True``.
    # NOTE(review): these classes are defined in ``__main__``; transformers
    # may be unable to copy their source next to the weights unless they live
    # in an importable module — confirm the pushed repo contains the code.
    CustomConfig.register_for_auto_class()
    # The zero-shot-classification pipeline loads models through
    # AutoModelForSequenceClassification, so register under that auto class
    # instead of the generic AutoModel.
    CustomModel.register_for_auto_class("AutoModelForSequenceClassification")

    p = ZeroShotClassificationPipeline(
        model=CustomModel(CustomConfig()),
        tokenizer=tokenizer,
    )

    # NOTE(review): ``huggingface_hub.Repository`` is deprecated in favour of
    # the HTTP-based ``HfApi.upload_folder``; consider migrating.
    local_dir = "zero-shot-classification"
    repo = Repository(local_dir, clone_from="paulhindemith/zero-shot-classification")
    p.save_pretrained(local_dir)
    repo.push_to_hub()