import torch
from transformers import PreTrainedModel, PretrainedConfig

from .config import MoondreamConfig
from .moondream import MoondreamModel

from .image_crops import *
from .vision import *
from .text import *
from .region import *
from .utils import *
|
|
|
|
def extract_question(text):
    """Pull the question out of a legacy prompt template.

    The old API wrapped questions as "<image>\\n\\nQuestion: {q}\\n\\nAnswer:".
    Returns the question text when *text* matches that template, else None.
    """
    prefix = "<image>\n\nQuestion: "
    suffix = "\n\nAnswer:"

    # Guard clause: anything that doesn't match the template is rejected.
    if not (text.startswith(prefix) and text.endswith(suffix)):
        return None
    return text[len(prefix) : -len(suffix)]
|
|
|
|
class HfConfig(PretrainedConfig):
    """Minimal HF config shim for Moondream checkpoints.

    The actual Moondream settings are stored as a plain dict in
    ``self.config`` and rebuilt into a ``MoondreamConfig`` by the model
    wrapper at load time.
    """

    model_type = "moondream1"
    _auto_class = "AutoConfig"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Raw nested config dict; populated from the checkpoint's
        # config.json when loading via AutoConfig.
        self.config = {}
|
|
|
|
class HfMoondream(PreTrainedModel):
    """HuggingFace ``AutoModelForCausalLM``-compatible wrapper for Moondream.

    Delegates the current API (``query``, ``caption``, ``detect``, ``point``,
    ``detect_gaze``, ``encode_image``) to the wrapped ``MoondreamModel`` and
    keeps thin backwards-compatibility shims (``answer_question``,
    ``batch_answer``, ``generate``) for the older revision's API.
    """

    _auto_class = "AutoModelForCausalLM"
    config_class = HfConfig

    def __init__(self, config):
        super().__init__(config)
        # `config.config` is the raw dict held by HfConfig; rebuild the
        # typed MoondreamConfig from it before constructing the model.
        self.model = MoondreamModel(MoondreamConfig.from_dict(config.config))

    @property
    def encode_image(self):
        # Properties (rather than methods) so the wrapper forwards the
        # wrapped model's bound callables without re-declaring signatures.
        return self.model.encode_image

    @property
    def query(self):
        return self.model.query

    @property
    def caption(self):
        return self.model.caption

    @property
    def detect(self):
        return self.model.detect

    @property
    def point(self):
        return self.model.point

    @property
    def detect_gaze(self):
        return self.model.detect_gaze

    def answer_question(
        self,
        image_embeds,
        question,
        tokenizer=None,
        chat_history="",
        result_queue=None,
        max_new_tokens=256,
        **kwargs
    ):
        """Legacy question-answering entry point.

        NOTE: ``tokenizer``, ``chat_history``, ``max_new_tokens``, and
        ``kwargs`` are accepted only for signature compatibility with the
        old API and are ignored.

        Returns the stripped answer string; if ``result_queue`` is given,
        the answer is also pushed onto it.
        """
        answer = self.query(image_embeds, question)["answer"].strip()

        if result_queue is not None:
            result_queue.put(answer)
        return answer

    def batch_answer(self, images, prompts, tokenizer=None, **kwargs):
        """Legacy batched QA: answer each (image, prompt) pair sequentially.

        ``tokenizer`` and ``kwargs`` are ignored (signature compatibility).
        """
        return [
            self.query(image, prompt)["answer"].strip()
            for image, prompt in zip(images, prompts)
        ]

    def _unsupported_exception(self):
        # Shared error for old-API methods that have no modern equivalent.
        raise NotImplementedError(
            "This method is not supported in the latest version of moondream. "
            "Consider upgrading to the updated API spec, or alternately pin "
            "to 'revision=2024-08-26'."
        )

    def generate(self, image_embeds, prompt, tokenizer, max_new_tokens=128, **kwargs):
        """
        Function definition remains unchanged for backwards compatibility.
        ``tokenizer`` and ``kwargs`` are ignored. ``max_new_tokens`` only
        takes effect when ``prompt`` does NOT match the legacy
        "<image>\\n\\nQuestion: ...\\n\\nAnswer:" template; the template path
        delegates to ``query`` with its own defaults.
        """
        prompt_extracted = extract_question(prompt)
        if prompt_extracted is not None:
            # Old-style prompt template: route through the modern query API.
            answer = self.model.query(
                image=image_embeds, question=prompt_extracted, stream=False
            )["answer"]
        else:
            # Free-form prompt: encode the image and decode token-by-token.
            image_embeds = self.encode_image(image_embeds)
            prompt_tokens = torch.tensor(
                [self.model.tokenizer.encode(prompt).ids],
                device=self.device,
            )
            # _generate_text yields decoded text pieces; join the iterator
            # directly rather than through an intermediate generator + list.
            answer = "".join(
                self.model._generate_text(
                    prompt_tokens,
                    image_embeds.kv_cache,
                    image_embeds.pos,
                    max_new_tokens,
                )
            )

        # Old API returned a list of answers even for a single prompt.
        return [answer]

    def get_input_embeddings(self):
        # Defer to the PreTrainedModel default lookup.
        return super().get_input_embeddings()

    def input_embeds(self, *args, **kwargs):
        # Removed from the modern API; see _unsupported_exception.
        self._unsupported_exception()
|
|