| import math |
| import copy |
| import torch |
| from torch.nn import functional as F |
| import torch.nn as nn |
|
|
| from .model_proteinglm_clm import ProteinGLMForGeneration |
|
|
|
|
class MSAGPT(ProteinGLMForGeneration):
    """Thin subclass of ``ProteinGLMForGeneration`` used for MSAGPT inference.

    Construction is delegated entirely to the parent class; this class only
    contributes an (empty) ``MSAGPT-inference`` argument group on the CLI
    parser.
    """

    def __init__(self, args, transformer=None, **kwargs):
        """Build the model by forwarding every argument to the parent.

        Args:
            args: Configuration object handed unchanged to the parent
                constructor.
            transformer: Optional value forwarded as the parent's
                ``transformer`` keyword; defaults to ``None``.
            **kwargs: Extra keyword arguments passed through untouched.
        """
        super().__init__(args, transformer=transformer, **kwargs)

    @classmethod
    def add_model_specific_args(cls, parser):
        """Register the MSAGPT argument group, then defer to the parent.

        Args:
            parser: Parser-like object exposing ``add_argument_group``.

        Returns:
            Whatever the parent's ``add_model_specific_args(parser)`` returns.
        """
        # The group is created for its effect on the parser; no options are
        # attached to it here, so the returned handle is not kept.
        parser.add_argument_group(
            'MSAGPT-inference',
            'MSAGPT inference Configurations',
        )
        return super().add_model_specific_args(parser)
|
|
class FineTuneMSAGPT(MSAGPT):
    """Subclass of ``MSAGPT`` reserved for fine-tuning.

    It currently adds no behavior of its own: the constructor simply hands
    its arguments to ``MSAGPT``.
    """

    def __init__(self, args, transformer=None, **kwargs):
        """Initialize by forwarding all arguments to ``MSAGPT.__init__``.

        Args:
            args: Configuration object passed unchanged to the parent.
            transformer: Optional value forwarded as the parent's
                ``transformer`` keyword; defaults to ``None``.
            **kwargs: Extra keyword arguments passed through untouched.
        """
        super().__init__(args, transformer=transformer, **kwargs)