import torch
from torch import nn
from safetensors.torch import load_file
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    CLIPImageProcessor,
    PreTrainedModel,
    PretrainedConfig,
)
from transformers.utils import cached_file

from .modeling_clipPT import CLIPVisionTransformer
from .modeling_qwen2 import Qwen2Model
from .modeling_timer import TimerForPrediction


class MulTiCastTimerConfig(PretrainedConfig):
    def __init__(
        self,
        forecasting_length=None,
        vision_model_name=None,
        text_model_name=None,
        vision_model_prompt_len=None,
        text_model_prompt_len=None,
        timer_prompt_len=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.forecasting_length = forecasting_length
        # Auxiliary encoders to attach; None disables that modality.
        self.vision_model_name = vision_model_name
        self.text_model_name = text_model_name
        # Prompt-tuning lengths; the defaults apply when None is passed
        # (e.g. from a serialized config that omits the field).
        self.vision_model_prompt_len = vision_model_prompt_len if vision_model_prompt_len is not None else 10
        self.text_model_prompt_len = text_model_prompt_len if text_model_prompt_len is not None else 4
        self.timer_prompt_len = timer_prompt_len if timer_prompt_len is not None else 4
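
# Example (a sketch; the values are illustrative, not from a shipped config):
# enable both auxiliary encoders and forecast 96 steps.
#
#   config = MulTiCastTimerConfig(
#       forecasting_length=96,
#       vision_model_name="CLIP",
#       text_model_name="Qwen",
#   )
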

class MulTiCastTimerModel(PreTrainedModel):

    config_class = MulTiCastTimerConfig

    def __init__(self, config):
        super().__init__(config)  # PreTrainedModel already exposes `config` as self.config

        # Optional CLIP vision tower. The pretrained weights load into the
        # prompt-augmented CLIPVisionTransformer; strict=False tolerates the
        # injected prompt parameters, which have no counterpart in the checkpoint.
        if config.vision_model_name == 'CLIP':
            vision_model = AutoModel.from_pretrained("openai/clip-vit-base-patch32").vision_model
            state_dict = {k: v.to(torch.bfloat16) for k, v in vision_model.state_dict().items()}
            self.vision_model = CLIPVisionTransformer(vision_model.config, config.vision_model_prompt_len)
            self.vision_model.load_state_dict(state_dict, strict=False)
            self.processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
            # Freeze the backbone; only the prompt parameters stay trainable.
            for name, param in self.vision_model.named_parameters():
                param.requires_grad = "encoder.prompts" in name
        elif config.vision_model_name is not None:
            raise ValueError(f"Unsupported vision_model_name: {config.vision_model_name!r}")

        # Optional Qwen2 text encoder, prompt-tuned the same way.
        if config.text_model_name == 'Qwen':
            self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-1.5B-Instruct")
            text_model = AutoModelForCausalLM.from_pretrained(
                "Qwen/Qwen2-1.5B-Instruct",
                torch_dtype=torch.bfloat16,
                device_map="cpu",
                attn_implementation="sdpa",
            ).model
            self.text_model = Qwen2Model(text_model.config, config.text_model_prompt_len)
            self.text_model.load_state_dict(text_model.state_dict(), strict=False)
            for name, param in self.text_model.named_parameters():
                param.requires_grad = "prompts" in name
        elif config.text_model_name is not None:
            raise ValueError(f"Unsupported text_model_name: {config.text_model_name!r}")

        # The Timer forecasting backbone is always loaded; again only its
        # prompt parameters remain trainable.
        timer = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)
        self.timer = TimerForPrediction(timer.config, config.timer_prompt_len)
        self.timer.load_state_dict(timer.state_dict(), strict=False)
        for name, param in self.timer.named_parameters():
            param.requires_grad = "model.prompts" in name

        # Linear adapters projecting each encoder's features into the Timer's hidden space.
        if config.vision_model_name is not None:
            self.vision_interaction_layer = nn.Linear(self.vision_model.config.hidden_size, self.timer.config.hidden_size)
        if config.text_model_name is not None:
            self.text_interaction_layer = nn.Linear(self.text_model.config.hidden_size, self.timer.config.hidden_size)

    @torch.no_grad()
    def predict(self, input_ids=None, images=None, texts=None):
        vision_embedding = vision_attentions = None
        text_embedding = text_tokens = text_attentions = None

        # Vision branch: preprocess raw images only when the tower exists and
        # an image was actually provided.
        if self.config.vision_model_name is not None and images is not None:
            pixel_values = self.processor(images, return_tensors="pt")["pixel_values"].to(self.device)
            vision_output = self.vision_model(pixel_values, output_attentions=True)
            vision_attentions = vision_output.attentions
            vision_embedding = self.vision_interaction_layer(vision_output.pooler_output)

        # Text branch: skipped when there is no encoder or no usable text.
        if self.config.text_model_name is not None and texts is not None and any(x is not None for x in texts):
            tokenized_texts = self.tokenizer(texts, return_tensors="pt").to(self.device)
            text_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_texts["input_ids"][0])
            text_output = self.text_model(**tokenized_texts, output_attentions=True)
            text_attentions = text_output.attentions
            # The first token's hidden state serves as the sequence-level embedding.
            text_embedding = self.text_interaction_layer(text_output.last_hidden_state[:, 0, :])

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)

        return {
            "logits": out.logits,
            "vision_attentions": vision_attentions,
            "text_tokens": text_tokens,
            "text_attentions": text_attentions,
            "time_series_attentions": out.attentions,
        }
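
    # Inference sketch (argument shapes are assumptions: `series` is the Timer's
    # float time-series input, `image` a PIL image or pixel array, and the texts
    # a list of strings):
    #
    #   out = model.predict(input_ids=series, images=image, texts=["hourly load, MW"])
    #   forecast = out["logits"]
    #   ts_attn = out["time_series_attentions"]  # may be None if attentions are not returned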

    def forward(self, input_ids=None, images=None, texts=None, labels=None):
        # Unlike predict(), forward() expects `images` as preprocessed pixel values.
        vision_embedding = None
        if self.config.vision_model_name is not None and images is not None:
            vision_output = self.vision_model(images)
            vision_embedding = self.vision_interaction_layer(vision_output.pooler_output)

        text_embedding = None
        if self.config.text_model_name is not None and texts is not None and any(x is not None for x in texts):
            tokenized_texts = self.tokenizer(texts, return_tensors="pt").to(self.device)
            text_output = self.text_model(**tokenized_texts)
            text_embedding = self.text_interaction_layer(text_output.last_hidden_state[:, 0, :])

        out = self.timer(input_ids=input_ids, vision_embedding=vision_embedding, text_embedding=text_embedding)
        logits = out["logits"]

        loss = None
        if labels is not None:
            # MSE over the forecast horizon; truncate when the head emits more
            # steps than `forecasting_length`.
            if self.config.forecasting_length == logits.shape[-1]:
                loss = torch.mean(torch.square(logits - labels))
            else:
                loss = torch.mean(torch.square(logits[:, :self.config.forecasting_length] - labels))

        return {
            "loss": loss,
            "logits": logits,
        }
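
    # Training sketch (assumed shapes: `series` is [batch, lookback] and
    # `targets` is [batch, forecasting_length]):
    #
    #   out = model(input_ids=series, images=pixel_values, texts=captions, labels=targets)
    #   out["loss"].backward()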

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = MulTiCastTimerConfig.from_pretrained(pretrained_model_name_or_path)
        model = cls(config)
        # strict=False tolerates checkpoints that store only a subset of the
        # weights (e.g. just the prompt and adapter parameters).
        resolved_file = cached_file(pretrained_model_name_or_path, "model.safetensors")
        model.load_state_dict(load_file(resolved_file), strict=False)
        return model
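

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the published pipeline): it
    # downloads the Timer backbone from the Hub and assumes TimerForPrediction
    # accepts a [batch, lookback] float tensor whose lookback is a multiple of
    # the Timer patch length (96 assumed here, so 672 = 7 * 96).
    config = MulTiCastTimerConfig(forecasting_length=96)
    model = MulTiCastTimerModel(config)
    series = torch.randn(1, 672)
    out = model(input_ids=series)
    print(out["logits"].shape)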