import torch
import transformers
from transformers import Pipeline


try:
    import orbitals.scope_guard
    import orbitals.scope_guard.modeling
    import orbitals.scope_guard.prompting
    import orbitals.types
except ModuleNotFoundError as e:
    raise ImportError(
        "orbitals.scope_guard module not found. Please install it: `pip install orbitals`"
    ) from e


class ScopeGuardPipeline(Pipeline):
    def __init__(
        self,
        model,
        tokenizer=None,
        skip_evidences: bool = False,
        max_new_tokens: int = 1024,
        do_sample: bool = False,
        **kwargs,
    ):
        # Both the model and the tokenizer may be passed as instantiated
        # objects or as Hugging Face Hub identifiers.
        if tokenizer is None and isinstance(model, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(model)
        elif isinstance(tokenizer, str):
            tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer)

        if isinstance(model, str):
            # dtype="auto" defers to the dtype recorded in the checkpoint;
            # device_map="auto" lets accelerate place the weights.
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model, dtype="auto", device_map="auto"
            )

        if tokenizer is not None:
            # Decoder-only models must be left-padded so that, in a batch,
            # every generated continuation starts at the same position.
            tokenizer.padding_side = "left"
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

|
| self.skip_evidences = skip_evidences |
| self.max_new_tokens = max_new_tokens |
| self.do_sample = do_sample |
|
|
| super().__init__(model, tokenizer, **kwargs) |
|
    def _sanitize_parameters(self, **kwargs):
        # Pipeline routes the three returned dicts to preprocess, _forward,
        # and postprocess, respectively; only preprocess takes extra
        # arguments here.
        preprocess_kwargs = {}
        if "skip_evidences" in kwargs or self.skip_evidences:
            preprocess_kwargs["skip_evidences"] = kwargs.get(
                "skip_evidences", self.skip_evidences
            )
        return preprocess_kwargs, {}, {}

    def preprocess(
        self,
        inputs: tuple[
            orbitals.scope_guard.modeling.ScopeGuardInput,
            str | orbitals.types.AIServiceDescription,
        ],
        skip_evidences: bool = False,
    ):
        conversation, ai_service_description = inputs

        model_messages = orbitals.scope_guard.prompting.prepare_messages(
            conversation,
            ai_service_description,
            skip_evidences,
        )

        # enable_thinking is forwarded to the chat template; templates that
        # support it (e.g. Qwen3) use it to suppress reasoning traces, and
        # templates that do not reference it ignore it.
        text = self.tokenizer.apply_chat_template(
            model_messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )

        return {"text": text}

    def _forward(self, model_inputs):
        tokenized = self.tokenizer(
            model_inputs["text"],
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)

        with torch.inference_mode():
            outputs = self.model.generate(
                **tokenized,
                max_new_tokens=self.max_new_tokens,
                do_sample=self.do_sample,
            )
        # Return the prompt ids alongside the output so postprocess can
        # strip the prompt from the decoded text.
        return {
            "output_ids": outputs,
            "input_ids": tokenized["input_ids"],
        }

    def postprocess(self, model_outputs):
        output_ids = model_outputs["output_ids"]
        input_ids = model_outputs["input_ids"]

        results = []
        for i in range(output_ids.shape[0]):
            # generate() returns prompt + continuation; slice off the prompt.
            # This is safe because the tokenizer left-pads, so every prompt
            # in the batch spans the first input_ids.shape[1] positions.
            generated_ids = output_ids[i][input_ids.shape[1]:]
            generated_output = self.tokenizer.decode(
                generated_ids,
                skip_special_tokens=True,
            )
            results.append({"generated_text": generated_output})

        return results
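

# ---------------------------------------------------------------------------
# A minimal usage sketch, not part of the pipeline itself. The checkpoint
# name and the ScopeGuardInput construction below are illustrative
# assumptions; consult the orbitals documentation for the actual API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical checkpoint identifier.
    pipe = ScopeGuardPipeline("orbitals/scope-guard", skip_evidences=True)

    # Hypothetical input; ScopeGuardInput's real constructor may differ.
    conversation = orbitals.scope_guard.modeling.ScopeGuardInput(
        messages=[{"role": "user", "content": "Can you write my essay for me?"}]
    )
    service = "A customer support assistant for an online bookstore."

    # The pipeline expects a (conversation, service description) tuple.
    print(pipe((conversation, service)))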