"""
Helion-2.5-Rnd Inference Pipeline
High-level pipeline for easy model usage
"""
|
|
| import logging |
| import time |
| from typing import Any, Dict, List, Optional, Union |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria, StoppingCriteriaList |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |
|
|
|
|
class StopOnTokens(StoppingCriteria):
    """Stop generation when specific tokens are generated.

    Used as a `StoppingCriteria` with `model.generate`; checks only the
    first sequence in the batch (index 0), matching the original behavior.
    """

    def __init__(self, stop_token_ids: List[int]):
        # Fix: drop None entries — a tokenizer without an EOS token yields
        # eos_token_id=None, and convert_tokens_to_ids can return None for
        # unknown tokens; comparing against None would never match anyway
        # and hides a misconfiguration.
        self.stop_token_ids = [
            token_id for token_id in stop_token_ids if token_id is not None
        ]

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
        **kwargs
    ) -> bool:
        # Extract the last generated token of sequence 0 as a plain int and
        # test membership instead of comparing a tensor per stop id.
        last_token = input_ids[0, -1].item()
        return last_token in self.stop_token_ids
|
|
|
|
class HelionPipeline:
    """High-level inference pipeline for the Helion causal language model.

    Wraps a HuggingFace tokenizer/model pair and provides convenience
    methods for single-prompt generation, chat formatting, batched
    generation, token streaming, embeddings, and perplexity scoring.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        torch_dtype=torch.bfloat16,
        load_in_8bit: bool = False,
        trust_remote_code: bool = True
    ):
        """
        Initialize Helion pipeline

        Args:
            model_path: Path to model or HuggingFace ID
            device: Device to load model on
            torch_dtype: Torch data type
            load_in_8bit: Whether to load in 8-bit
            trust_remote_code: Trust remote code
        """
        logger.info(f"Loading Helion model from {model_path}")

        self.device = device
        self.model_path = model_path

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=trust_remote_code
        )

        # Fix: many causal-LM checkpoints ship without a pad token, which
        # breaks padded batching and passes pad_token_id=None to generate().
        # Reuse EOS as the pad token (standard practice for decoder-only LMs).
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Fix: decoder-only models must be LEFT-padded for batched
        # generation — with right padding the model generates after pad
        # tokens, corrupting shorter sequences in the batch.
        self.tokenizer.padding_side = "left"

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            device_map="auto" if device == "cuda" else None,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code
        )

        # device_map="auto" (cuda) and 8-bit loading already place the
        # weights; only move the model manually for other devices.
        if device != "cuda" and not load_in_8bit:
            self.model = self.model.to(device)

        self.model.eval()

        # Fix: eos_token_id may be None and convert_tokens_to_ids may
        # return None for a token missing from the vocab; filter those out
        # so StopOnTokens never compares against None.
        self.stop_token_ids = [
            token_id
            for token_id in (
                self.tokenizer.eos_token_id,
                self.tokenizer.convert_tokens_to_ids("<|im_end|>"),
            )
            if token_id is not None
        ]

        logger.info("Model loaded successfully")

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: int = 50,
        repetition_penalty: float = 1.1,
        do_sample: bool = True,
        num_return_sequences: int = 1,
        **kwargs
    ) -> Union[str, List[str]]:
        """
        Generate text from prompt

        Args:
            prompt: Input prompt
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            repetition_penalty: Repetition penalty
            do_sample: Whether to sample
            num_return_sequences: Number of sequences to return
            **kwargs: Additional generation parameters

        Returns:
            Generated text (str) when num_return_sequences == 1,
            otherwise a list of texts
        """
        # Truncate to the model context window so over-long prompts do not
        # raise a position-embedding error.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings
        ).to(self.device)

        stopping_criteria = StoppingCriteriaList([
            StopOnTokens(self.stop_token_ids)
        ])

        with torch.no_grad():
            start_time = time.time()

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                repetition_penalty=repetition_penalty,
                do_sample=do_sample,
                num_return_sequences=num_return_sequences,
                stopping_criteria=stopping_criteria,
                pad_token_id=self.tokenizer.pad_token_id,
                **kwargs
            )

            generation_time = time.time() - start_time

        # Decode only the continuation: generate() returns prompt + new
        # tokens, so slice off the prompt length.
        prompt_length = inputs['input_ids'].shape[1]
        generated_texts = [
            self.tokenizer.decode(
                output[prompt_length:],
                skip_special_tokens=True
            ).strip()
            for output in outputs
        ]

        logger.info(f"Generated {len(generated_texts)} sequences in {generation_time:.2f}s")

        if num_return_sequences == 1:
            return generated_texts[0]
        return generated_texts

    def chat(
        self,
        messages: List[Dict[str, str]],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Chat completion

        Args:
            messages: List of message dictionaries with 'role'/'content' keys
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Returns:
            Assistant response
        """
        prompt = self._format_chat_prompt(messages)

        response = self.generate(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        return response

    def _format_chat_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Format messages into a ChatML-style prompt ending with an open
        assistant turn for the model to complete."""
        formatted = ""

        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            formatted += f"<|im_start|>{role}\n{content}<|im_end|>\n"

        formatted += "<|im_start|>assistant\n"
        return formatted

    def batch_generate(
        self,
        prompts: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        batch_size: int = 4,
        do_sample: bool = True,
        **kwargs
    ) -> List[str]:
        """
        Generate for multiple prompts in batches

        Args:
            prompts: List of input prompts
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            batch_size: Batch size for processing
            do_sample: Whether to sample (required for temperature to apply)
            **kwargs: Additional generation parameters

        Returns:
            List of generated texts, in prompt order
        """
        all_outputs = []

        for i in range(0, len(prompts), batch_size):
            batch = prompts[i:i + batch_size]

            inputs = self.tokenizer(
                batch,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.model.config.max_position_embeddings
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    # Fix: without do_sample=True, generate() ignores
                    # temperature (greedy decoding) and the parameter was
                    # silently dead; mirrors generate()'s default.
                    do_sample=do_sample,
                    pad_token_id=self.tokenizer.pad_token_id,
                    **kwargs
                )

            # All rows are padded to the same length, so the continuation
            # starts at the shared (padded) prompt length.
            prompt_length = inputs['input_ids'].shape[1]
            for output in outputs:
                text = self.tokenizer.decode(
                    output[prompt_length:],
                    skip_special_tokens=True
                )
                all_outputs.append(text.strip())

        logger.info(f"Generated {len(all_outputs)} outputs")
        return all_outputs

    def stream_generate(
        self,
        prompt: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        do_sample: bool = True,
        **kwargs
    ):
        """
        Stream generation token by token

        NOTE(review): this re-runs generate() per token (no KV-cache reuse),
        so it is O(n^2) in sequence length — acceptable for demos, not for
        production streaming.

        Args:
            prompt: Input prompt
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            do_sample: Whether to sample (required for temperature to apply)
            **kwargs: Additional generation parameters

        Yields:
            Decoded text of each newly generated token
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        input_length = inputs['input_ids'].shape[1]

        stopping_criteria = StoppingCriteriaList([
            StopOnTokens(self.stop_token_ids)
        ])

        with torch.no_grad():
            for _ in range(max_new_tokens):
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=1,
                    temperature=temperature,
                    # Fix: temperature was silently ignored without sampling.
                    do_sample=do_sample,
                    stopping_criteria=stopping_criteria,
                    pad_token_id=self.tokenizer.pad_token_id,
                    **kwargs
                )

                new_token_id = outputs[0, -1].item()

                # Stop silently on EOS / <|im_end|> without yielding it.
                if new_token_id in self.stop_token_ids:
                    break

                new_token = self.tokenizer.decode([new_token_id])
                yield new_token

                # Feed the grown sequence back in as the next step's prompt.
                inputs = {
                    'input_ids': outputs,
                    'attention_mask': torch.ones_like(outputs)
                }

    def get_embeddings(self, text: str) -> torch.Tensor:
        """
        Get embeddings for text

        Mean-pools the last hidden layer over the sequence dimension.

        Args:
            text: Input text

        Returns:
            Embedding tensor of shape (1, hidden_size)
        """
        # Fix: truncate so over-long inputs cannot exceed the context window.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
            embeddings = outputs.hidden_states[-1].mean(dim=1)

        return embeddings

    def score_text(self, text: str) -> float:
        """
        Calculate perplexity score for text

        Args:
            text: Input text

        Returns:
            Perplexity (exp of the mean cross-entropy loss); lower is better
        """
        # Fix: truncate so over-long inputs cannot exceed the context window.
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=self.model.config.max_position_embeddings
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(**inputs, labels=inputs['input_ids'])
            loss = outputs.loss
            perplexity = torch.exp(loss).item()

        return perplexity

    def cleanup(self):
        """Clean up resources (drop model/tokenizer refs, free CUDA cache)."""
        del self.model
        del self.tokenizer
        # Fix: only touch the CUDA allocator when CUDA actually exists.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("Pipeline cleaned up")
|
|
|
|
class ConversationPipeline(HelionPipeline):
    """Pipeline with conversation history management"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Running transcript of user/assistant turns, oldest first.
        self.conversation_history: List[Dict[str, str]] = []
        # Optional system prompt prepended to every request.
        self.system_prompt: Optional[str] = None

    def set_system_prompt(self, prompt: str):
        """Set system prompt for conversation"""
        self.system_prompt = prompt

    def add_message(self, role: str, content: str):
        """Add message to conversation history"""
        self.conversation_history.append({'role': role, 'content': content})

    def generate_response(
        self,
        user_message: str,
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate response in conversation context

        Args:
            user_message: User's message
            max_new_tokens: Maximum tokens to generate
            temperature: Sampling temperature
            **kwargs: Additional generation parameters

        Returns:
            Assistant response
        """
        # Assemble the request: optional system prompt, prior turns, then
        # the new user message.
        preamble = (
            [{'role': 'system', 'content': self.system_prompt}]
            if self.system_prompt
            else []
        )
        messages = (
            preamble
            + list(self.conversation_history)
            + [{'role': 'user', 'content': user_message}]
        )

        response = self.chat(
            messages,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            **kwargs
        )

        # Record both sides of the exchange after a successful reply.
        self.add_message('user', user_message)
        self.add_message('assistant', response)

        return response

    def reset_conversation(self):
        """Reset conversation history"""
        self.conversation_history.clear()
        logger.info("Conversation history reset")
|
|
|
|
def main():
    """Example usage"""
    # Basic pipeline: single-prompt generation.
    helion = HelionPipeline(
        model_path="DeepXR/Helion-2.5-Rnd",
        device="cuda"
    )

    answer = helion.generate(
        "Explain quantum computing in simple terms:",
        max_new_tokens=256
    )
    print(f"Response: {answer}\n")

    # Chat-style completion with a system message.
    chat_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is the capital of France?"}
    ]
    chat_answer = helion.chat(chat_messages)
    print(f"Chat response: {chat_answer}\n")

    # Batched generation over several prompts.
    batch_prompts = [
        "Write a haiku about AI:",
        "Explain machine learning:",
        "What is Python?"
    ]
    batch_answers = helion.batch_generate(batch_prompts, batch_size=2)
    for idx, text in enumerate(batch_answers, start=1):
        print(f"Batch {idx}: {text}\n")

    # Stateful multi-turn conversation.
    assistant = ConversationPipeline(
        model_path="DeepXR/Helion-2.5-Rnd",
        device="cuda"
    )
    assistant.set_system_prompt("You are a helpful coding assistant.")

    first_reply = assistant.generate_response("How do I sort a list in Python?")
    print(f"Conv 1: {first_reply}\n")

    second_reply = assistant.generate_response("Can you show me an example?")
    print(f"Conv 2: {second_reply}\n")
|
|
|
|
| if __name__ == "__main__": |
| main() |