| """ |
| Helion-OSC Inference Script |
| DeepXR/Helion-OSC - Mathematical Coding Language Model |
| |
| This module provides comprehensive inference capabilities for the Helion-OSC model, |
| including specialized methods for different programming and mathematical tasks. |
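
Typical usage (an illustrative sketch; assumes the DeepXR/Helion-OSC checkpoint is
accessible on the Hugging Face Hub and that enough memory is available for it; the
module name in the import is hypothetical):

    from helion_osc_inference import HelionOSCInference

    helion = HelionOSCInference()
    code = helion.code_generation(
        "Write a Python function that merges two sorted lists",
        language="python"
    )
    print(code)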
| """ |
|
|
| import torch |
| import json |
| import logging |
| from typing import Optional, Dict, Any, List, Union |
| from transformers import ( |
| AutoTokenizer, |
| AutoModelForCausalLM, |
| GenerationConfig, |
| StoppingCriteria, |
| StoppingCriteriaList |
| ) |
| from dataclasses import dataclass |
| import warnings |
|
|
| |
| logging.basicConfig( |
| level=logging.INFO, |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' |
| ) |
| logger = logging.getLogger(__name__) |
|
|
|
|
| @dataclass |
| class GenerationParameters: |
| """Parameters for text generation""" |
| max_length: int = 2048 |
| temperature: float = 0.7 |
| top_p: float = 0.95 |
| top_k: int = 50 |
| repetition_penalty: float = 1.05 |
| length_penalty: float = 1.0 |
| do_sample: bool = True |
| num_return_sequences: int = 1 |
| early_stopping: bool = False |
|
|
|
|
| class CodeStoppingCriteria(StoppingCriteria): |
| """Custom stopping criteria for code generation""" |
| |
| def __init__(self, stop_sequences: List[str], tokenizer): |
| self.stop_sequences = stop_sequences |
| self.tokenizer = tokenizer |
| |
| def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: |
| decoded = self.tokenizer.decode(input_ids[0], skip_special_tokens=True) |
| return any(seq in decoded for seq in self.stop_sequences) |
|
|
|
|
| class HelionOSCInference: |
| """ |
| Comprehensive inference wrapper for Helion-OSC model |
| |
| Supports multiple generation modes: |
| - Code generation |
| - Mathematical reasoning |
| - Algorithm design |
| - Code debugging |
| - Documentation generation |
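
    Example (illustrative; the quantized path assumes the optional bitsandbytes
    package is installed):

        # Full-precision load on the best available device
        helion = HelionOSCInference()

        # Memory-constrained GPU: load the weights in 4-bit instead
        helion = HelionOSCInference(load_in_4bit=True)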
| """ |
| |
| def __init__( |
| self, |
| model_name: str = "DeepXR/Helion-OSC", |
| device: Optional[str] = None, |
| load_in_8bit: bool = False, |
| load_in_4bit: bool = False, |
| use_flash_attention: bool = True, |
| trust_remote_code: bool = True |
| ): |
| """ |
| Initialize the Helion-OSC model |
| |
| Args: |
| model_name: HuggingFace model identifier |
| device: Device to load model on (cuda/cpu/mps) |
| load_in_8bit: Load model in 8-bit precision |
| load_in_4bit: Load model in 4-bit precision |
| use_flash_attention: Use flash attention for faster inference |
| trust_remote_code: Trust remote code from model repository |
| """ |
| self.model_name = model_name |
| self.device = self._get_device(device) |
| self.load_in_8bit = load_in_8bit |
| self.load_in_4bit = load_in_4bit |
| |
| logger.info(f"Initializing Helion-OSC on {self.device}...") |
| |
| |
| self.tokenizer = self._load_tokenizer(trust_remote_code) |
| |
| |
| self.model = self._load_model( |
| use_flash_attention=use_flash_attention, |
| trust_remote_code=trust_remote_code |
| ) |
| |
| |
| self.generation_configs = self._load_generation_configs() |
| |
| logger.info("Model loaded successfully!") |
| self._print_model_info() |
| |
| def _get_device(self, device: Optional[str]) -> str: |
| """Determine the best available device""" |
| if device: |
| return device |
| if torch.cuda.is_available(): |
| return "cuda" |
| elif torch.backends.mps.is_available(): |
| return "mps" |
| return "cpu" |
| |
| def _load_tokenizer(self, trust_remote_code: bool): |
| """Load and configure tokenizer""" |
| logger.info("Loading tokenizer...") |
| tokenizer = AutoTokenizer.from_pretrained( |
| self.model_name, |
| trust_remote_code=trust_remote_code, |
| padding_side="left" |
| ) |
| |
| |
| if tokenizer.pad_token is None: |
| tokenizer.pad_token = tokenizer.eos_token |
| |
| return tokenizer |
| |
| def _load_model(self, use_flash_attention: bool, trust_remote_code: bool): |
| """Load and configure model""" |
| logger.info("Loading model...") |
| |
| model_kwargs = { |
| "trust_remote_code": trust_remote_code, |
| "low_cpu_mem_usage": True |
| } |
| |
| |
| if self.load_in_8bit: |
| model_kwargs["load_in_8bit"] = True |
| logger.info("Loading in 8-bit precision") |
| elif self.load_in_4bit: |
| model_kwargs["load_in_4bit"] = True |
| model_kwargs["bnb_4bit_compute_dtype"] = torch.bfloat16 |
| model_kwargs["bnb_4bit_use_double_quant"] = True |
| model_kwargs["bnb_4bit_quant_type"] = "nf4" |
| logger.info("Loading in 4-bit precision") |
| else: |
| if self.device == "cuda": |
| model_kwargs["torch_dtype"] = torch.bfloat16 |
| else: |
| model_kwargs["torch_dtype"] = torch.float32 |
| |
| |
| if self.device == "cuda" and not (self.load_in_8bit or self.load_in_4bit): |
| model_kwargs["device_map"] = "auto" |
| |
| |
| model = AutoModelForCausalLM.from_pretrained( |
| self.model_name, |
| **model_kwargs |
| ) |
| |
| |
| if self.device != "cuda" or (self.load_in_8bit or self.load_in_4bit): |
| if not (self.load_in_8bit or self.load_in_4bit): |
| model = model.to(self.device) |
| |
| model.eval() |
| |
| |
| if hasattr(model, 'gradient_checkpointing_enable'): |
| model.gradient_checkpointing_enable() |
| |
| return model |
| |
| def _load_generation_configs(self) -> Dict[str, GenerationParameters]: |
| """Load task-specific generation configurations""" |
| return { |
| "code_generation": GenerationParameters( |
| max_length=4096, |
| temperature=0.7, |
| top_p=0.95, |
| top_k=50, |
| repetition_penalty=1.05, |
| do_sample=True |
| ), |
| "mathematical_reasoning": GenerationParameters( |
| max_length=2048, |
| temperature=0.3, |
| top_p=0.9, |
| top_k=40, |
| repetition_penalty=1.0, |
| do_sample=False |
| ), |
| "code_completion": GenerationParameters( |
| max_length=1024, |
| temperature=0.6, |
| top_p=0.92, |
| top_k=45, |
| repetition_penalty=1.03, |
| do_sample=True |
| ), |
| "algorithm_design": GenerationParameters( |
| max_length=3072, |
| temperature=0.5, |
| top_p=0.93, |
| top_k=50, |
| repetition_penalty=1.08, |
| do_sample=True |
| ), |
| "debugging": GenerationParameters( |
| max_length=2048, |
| temperature=0.4, |
| top_p=0.88, |
| repetition_penalty=1.0, |
| do_sample=False |
| ) |
| } |
| |
| def _print_model_info(self): |
| """Print model information""" |
| try: |
| num_params = sum(p.numel() for p in self.model.parameters()) |
| logger.info(f"Model parameters: {num_params:,}") |
| logger.info(f"Model dtype: {next(self.model.parameters()).dtype}") |
| logger.info(f"Device: {self.device}") |
| except Exception as e: |
| logger.warning(f"Could not get model info: {e}") |
| |
| def generate( |
| self, |
| prompt: Union[str, List[str]], |
| task_type: str = "code_generation", |
| custom_params: Optional[GenerationParameters] = None, |
| stop_sequences: Optional[List[str]] = None, |
| return_full_text: bool = False, |
| **kwargs |
| ) -> Union[str, List[str]]: |
| """ |
| Generate text based on prompt |
| |
| Args: |
| prompt: Input prompt or list of prompts |
| task_type: Type of task (code_generation, mathematical_reasoning, etc.) |
| custom_params: Custom generation parameters |
| stop_sequences: List of sequences to stop generation |
| return_full_text: Whether to return full text including prompt |
| **kwargs: Additional generation parameters |
| |
| Returns: |
| Generated text or list of generated texts |
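
        Example (illustrative; assumes helion is an initialized HelionOSCInference
        instance and the override values are arbitrary):

            params = GenerationParameters(temperature=0.2, max_length=1024)
            text = helion.generate(
                "Implement quicksort in Python",
                task_type="code_generation",
                custom_params=params,
                stop_sequences=["###"]
            )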
| """ |
| |
| if custom_params: |
| params = custom_params |
| elif task_type in self.generation_configs: |
| params = self.generation_configs[task_type] |
| else: |
| logger.warning(f"Unknown task type '{task_type}', using default parameters") |
| params = GenerationParameters() |
| |
| |
| for key, value in kwargs.items(): |
| if hasattr(params, key): |
| setattr(params, key, value) |
| |
| |
| is_batch = isinstance(prompt, list) |
| inputs = self.tokenizer( |
| prompt, |
| return_tensors="pt", |
| padding=True, |
| truncation=True, |
| max_length=self.model.config.max_position_embeddings |
| ).to(self.device) |
| |
| |
| stopping_criteria = None |
| if stop_sequences: |
| stopping_criteria = StoppingCriteriaList([ |
| CodeStoppingCriteria(stop_sequences, self.tokenizer) |
| ]) |
| |
| |
| with torch.no_grad(): |
| outputs = self.model.generate( |
| **inputs, |
| max_length=params.max_length, |
| temperature=params.temperature, |
| top_p=params.top_p, |
| top_k=params.top_k, |
| repetition_penalty=params.repetition_penalty, |
| length_penalty=params.length_penalty, |
| do_sample=params.do_sample, |
| num_return_sequences=params.num_return_sequences, |
| early_stopping=params.early_stopping, |
| pad_token_id=self.tokenizer.pad_token_id, |
| eos_token_id=self.tokenizer.eos_token_id, |
| stopping_criteria=stopping_criteria |
| ) |
| |
| |
| generated_texts = [] |
| for output in outputs: |
| text = self.tokenizer.decode(output, skip_special_tokens=True) |
| if not return_full_text and not is_batch: |
| |
| if isinstance(prompt, str): |
| text = text[len(prompt):].strip() |
| generated_texts.append(text) |
| |
| return generated_texts if is_batch or params.num_return_sequences > 1 else generated_texts[0] |
| |
| def code_generation( |
| self, |
| prompt: str, |
| language: Optional[str] = None, |
| max_length: int = 4096, |
| **kwargs |
| ) -> str: |
| """ |
| Generate code for a given prompt |
| |
| Args: |
| prompt: Code generation prompt |
| language: Programming language (optional) |
| max_length: Maximum length of generated code |
| **kwargs: Additional generation parameters |
| |
| Returns: |
| Generated code |
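
        Example (illustrative prompt; helion is an initialized instance):

            code = helion.code_generation(
                "Write a function that validates an email address",
                language="python"
            )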
| """ |
| if language: |
| prompt = f"Language: {language}\n{prompt}" |
| |
| return self.generate( |
| prompt, |
| task_type="code_generation", |
| max_length=max_length, |
| **kwargs |
| ) |
| |
| def mathematical_reasoning( |
| self, |
| prompt: str, |
| max_length: int = 2048, |
| **kwargs |
| ) -> str: |
| """ |
| Solve mathematical problems with step-by-step reasoning |
| |
| Args: |
| prompt: Mathematical problem |
| max_length: Maximum length of solution |
| **kwargs: Additional generation parameters |
| |
| Returns: |
| Mathematical solution with reasoning |
| """ |
| return self.generate( |
| prompt, |
| task_type="mathematical_reasoning", |
| max_length=max_length, |
| **kwargs |
| ) |
| |
| def algorithm_design( |
| self, |
| prompt: str, |
| include_complexity: bool = True, |
| max_length: int = 3072, |
| **kwargs |
| ) -> str: |
| """ |
| Design algorithms with complexity analysis |
| |
| Args: |
| prompt: Algorithm design prompt |
| include_complexity: Whether to include complexity analysis |
| max_length: Maximum length of output |
| **kwargs: Additional generation parameters |
| |
| Returns: |
| Algorithm design with analysis |
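
        Example (illustrative prompt; helion is an initialized instance):

            design = helion.algorithm_design(
                "Design an algorithm to detect cycles in a directed graph",
                include_complexity=True
            )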
| """ |
| if include_complexity: |
| prompt += "\n\nPlease include time and space complexity analysis." |
| |
| return self.generate( |
| prompt, |
| task_type="algorithm_design", |
| max_length=max_length, |
| **kwargs |
| ) |
| |
| def debug_code( |
| self, |
| code: str, |
| error_message: Optional[str] = None, |
| max_length: int = 2048, |
| **kwargs |
| ) -> str: |
| """ |
| Debug code and provide fixes |
| |
| Args: |
| code: Code to debug |
| error_message: Optional error message |
| max_length: Maximum length of output |
| **kwargs: Additional generation parameters |
| |
| Returns: |
| Debugging analysis and fixes |
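
        Example (illustrative; the snippet and error text are placeholders):

            analysis = helion.debug_code(
                "print(1 / 0)",
                error_message="ZeroDivisionError: division by zero"
            )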
| """ |
| prompt = f"Debug the following code:\n\n```\n{code}\n```" |
| if error_message: |
| prompt += f"\n\nError message: {error_message}" |
| prompt += "\n\nProvide a detailed explanation and fixed code." |
| |
| return self.generate( |
| prompt, |
| task_type="debugging", |
| max_length=max_length, |
| **kwargs |
| ) |
| |
| def complete_code( |
| self, |
| code_context: str, |
| max_length: int = 1024, |
| **kwargs |
| ) -> str: |
| """ |
| Complete partial code |
| |
| Args: |
| code_context: Partial code to complete |
| max_length: Maximum length of completion |
| **kwargs: Additional generation parameters |
| |
| Returns: |
| Code completion |
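
        Example (illustrative; generation stops at the default stop sequences):

            completion = helion.complete_code("def factorial(n):")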
| """ |
| return self.generate( |
| code_context, |
| task_type="code_completion", |
| max_length=max_length, |
| stop_sequences=["\n\n", "```", "###"], |
| **kwargs |
| ) |
| |
| def batch_generate( |
| self, |
| prompts: List[str], |
| task_type: str = "code_generation", |
| batch_size: int = 4, |
| **kwargs |
| ) -> List[str]: |
| """ |
| Generate responses for multiple prompts in batches |
| |
| Args: |
| prompts: List of prompts |
| task_type: Type of task |
| batch_size: Batch size for processing |
| **kwargs: Additional generation parameters |
| |
| Returns: |
| List of generated responses |
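
        Example (illustrative prompts, processed two at a time):

            answers = helion.batch_generate(
                ["Reverse a string in Go", "Sum a list in Haskell"],
                batch_size=2
            )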
| """ |
| results = [] |
| for i in range(0, len(prompts), batch_size): |
| batch = prompts[i:i + batch_size] |
| batch_results = self.generate(batch, task_type=task_type, **kwargs) |
| if isinstance(batch_results, str): |
| batch_results = [batch_results] |
| results.extend(batch_results) |
| return results |
|
|
|
|
| def main(): |
| """Example usage and demonstrations""" |
| print("=" * 80) |
| print("Helion-OSC Inference Examples") |
| print("=" * 80) |
| |
| |
| helion = HelionOSCInference( |
| load_in_8bit=False, |
| load_in_4bit=False |
| ) |
| |
| |
| print("\n" + "=" * 80) |
| print("Example 1: Code Generation") |
| print("=" * 80) |
| code_prompt = """Write a Python function to implement a binary search tree with the following methods: |
| - insert(value): Insert a new value |
| - search(value): Search for a value |
| - delete(value): Delete a value |
| - inorder_traversal(): Return inorder traversal |
| |
| Include proper documentation and type hints.""" |
| |
| print(f"\nPrompt:\n{code_prompt}") |
| print("\nGenerating...") |
| result = helion.code_generation(code_prompt, language="python") |
| print(f"\nGenerated Code:\n{result}") |
| |
| |
| print("\n" + "=" * 80) |
| print("Example 2: Mathematical Reasoning") |
| print("=" * 80) |
| math_prompt = """Prove that the sum of the first n natural numbers equals n(n+1)/2 using mathematical induction.""" |
| |
| print(f"\nPrompt:\n{math_prompt}") |
| print("\nGenerating...") |
| result = helion.mathematical_reasoning(math_prompt) |
| print(f"\nSolution:\n{result}") |
| |
| |
| print("\n" + "=" * 80) |
| print("Example 3: Algorithm Design") |
| print("=" * 80) |
| algo_prompt = """Design an efficient algorithm to find the longest palindromic substring in a given string.""" |
| |
| print(f"\nPrompt:\n{algo_prompt}") |
| print("\nGenerating...") |
| result = helion.algorithm_design(algo_prompt, include_complexity=True) |
| print(f"\nAlgorithm:\n{result}") |
| |
| |
| print("\n" + "=" * 80) |
| print("Example 4: Code Debugging") |
| print("=" * 80) |
| buggy_code = """ |
| def fibonacci(n): |
| if n <= 1: |
| return n |
| return fibonacci(n-1) + fibonacci(n-2) |
| |
| # This is too slow for large n |
| result = fibonacci(100) |
| """ |
| |
| print(f"\nBuggy Code:\n{buggy_code}") |
| print("\nGenerating debugging analysis...") |
| result = helion.debug_code(buggy_code, error_message="Takes too long to compute") |
| print(f"\nDebug Analysis:\n{result}") |
| |
| |
| print("\n" + "=" * 80) |
| print("Example 5: Batch Code Generation") |
| print("=" * 80) |
| batch_prompts = [ |
| "Write a Python function to reverse a linked list", |
| "Write a JavaScript function to debounce API calls", |
| "Write a Rust function to parse JSON safely" |
| ] |
| |
| print("\nProcessing batch prompts...") |
| results = helion.batch_generate(batch_prompts, batch_size=2) |
| for i, (prompt, result) in enumerate(zip(batch_prompts, results), 1): |
| print(f"\nPrompt {i}: {prompt}") |
| print(f"Result {i}:\n{result}\n") |
| |
| print("=" * 80) |
| print("Examples completed!") |
| print("=" * 80) |
|
|
|
|
| if __name__ == "__main__": |
| main() |