| """ |
| SHIVIK-Code Model Implementation |
| |
| This is a modified version of SHIVIK-M4 with: |
| - Extended context length (32K) via YaRN RoPE scaling |
| - Tool calling capabilities |
| - Fill-in-the-Middle support |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import LlamaForCausalLM, LlamaConfig |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
| from typing import Optional, Tuple, List, Union |
|
|
|
|
class ShivikCodeConfig(LlamaConfig):
    """Configuration for the SHIVIK-Code model.

    Extends ``LlamaConfig`` with defaults for a 32K context window
    (4096 native positions x YaRN factor 8.0) and with the special
    token ids that mark tool calls and tool results for
    ``ShivikCodeForCausalLM``.

    Args:
        vocab_size: Size of the token vocabulary.
        hidden_size: Transformer hidden dimension.
        intermediate_size: MLP intermediate dimension.
        num_hidden_layers: Number of decoder layers.
        num_attention_heads: Number of attention heads.
        num_key_value_heads: Number of KV heads (grouped-query attention).
        max_position_embeddings: Maximum sequence length (32768).
        rope_theta: RoPE base frequency.
        rope_scaling: RoPE scaling dict; defaults to YaRN with factor 8.0
            over 4096 original positions, matching
            ``max_position_embeddings``.
        tool_call_start_id: Token id opening a tool call, or ``None``.
        tool_call_end_id: Token id closing a tool call, or ``None``.
        tool_result_start_id: Token id opening a tool result, or ``None``.
        tool_result_end_id: Token id closing a tool result, or ``None``.
    """

    model_type = "shivik_code"

    def __init__(
        self,
        vocab_size=128279,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=16,
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=32768,
        rope_theta=500000.0,
        rope_scaling=None,
        tool_call_start_id=None,
        tool_call_end_id=None,
        tool_result_start_id=None,
        tool_result_end_id=None,
        **kwargs,
    ):
        # Default to YaRN scaling so the 4096-token base model covers the
        # full 32768-position window (4096 * 8.0 = 32768).
        if rope_scaling is None:
            rope_scaling = {
                "type": "yarn",
                "factor": 8.0,
                "original_max_position_embeddings": 4096,
            }

        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            num_key_value_heads=num_key_value_heads,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            **kwargs,
        )

        # Tool-marker token ids.  These were previously hard-reset to None
        # AFTER super().__init__(), which silently discarded any values that
        # arrived via **kwargs (e.g. when a saved config is reloaded, since
        # PretrainedConfig round-trips extra attributes through kwargs).
        # Accepting them as explicit parameters keeps the old default (None)
        # while letting them survive serialization.
        self.tool_call_start_id = tool_call_start_id
        self.tool_call_end_id = tool_call_end_id
        self.tool_result_start_id = tool_result_start_id
        self.tool_result_end_id = tool_result_end_id
|
|
|
|
class ShivikCodeForCausalLM(LlamaForCausalLM):
    """
    SHIVIK-Code: An agentic coding model.

    Extends LlamaForCausalLM with:
    - Tool calling support
    - Extended context via YaRN
    - FIM capability

    The token-level helpers (`_contains_tool_call`, `_extract_tool_call`,
    `_format_tool_result`) are abstract hooks: subclasses (or a tokenizer-
    aware wrapper) must implement them before `generate_with_tools` is used.
    """

    config_class = ShivikCodeConfig

    def __init__(self, config: ShivikCodeConfig):
        super().__init__(config)

        # Cache the tool-marker token ids from the config; any entry may be
        # None if the corresponding special token was never registered.
        self.tool_tokens = {
            "call_start": config.tool_call_start_id,
            "call_end": config.tool_call_end_id,
            "result_start": config.tool_result_start_id,
            "result_end": config.tool_result_end_id,
        }

    def is_tool_call(self, token_id: int) -> bool:
        """Return True if *token_id* is the tool-call start or end marker.

        A ``None`` token id never matches.  (Previously, when the marker
        ids were unset, ``None in [None, None]`` made this spuriously
        return True for ``token_id=None``.)
        """
        markers = (self.tool_tokens["call_start"], self.tool_tokens["call_end"])
        return token_id is not None and token_id in markers

    def generate_with_tools(
        self,
        input_ids: torch.Tensor,
        tool_executor,
        max_new_tokens: int = 512,
        max_tool_calls: int = 10,
        **generate_kwargs,
    ):
        """
        Generate with automatic tool execution.

        Args:
            input_ids: Input token IDs.
            tool_executor: Function that takes tool call JSON and returns result.
            max_new_tokens: Max tokens per generation step.
            max_tool_calls: Max number of tool calls allowed.

        Returns:
            Full generated sequence including tool results.  If the tool-call
            budget is exhausted, the sequence accumulated so far is returned
            (it may end on an unanswered tool result).
        """
        current_ids = input_ids
        tool_call_count = 0

        while tool_call_count < max_tool_calls:
            # NOTE(review): HF `generate` only honors `stop_strings` when a
            # `tokenizer=` kwarg is also passed; callers must supply it via
            # **generate_kwargs or the stop condition is ignored/raises.
            outputs = self.generate(
                current_ids,
                max_new_tokens=max_new_tokens,
                stop_strings=["</tool_call>"],
                **generate_kwargs,
            )

            generated = outputs[0]

            if self._contains_tool_call(generated):
                # Execute the tool and splice its result back into the
                # context so the next generation round can continue.
                tool_call = self._extract_tool_call(generated)
                tool_result = tool_executor(tool_call)

                result_tokens = self._format_tool_result(tool_result)
                current_ids = torch.cat([generated, result_tokens], dim=-1)
                tool_call_count += 1
            else:
                # No tool call in this round: generation is complete.
                return generated

        return current_ids

    def _contains_tool_call(self, token_ids: torch.Tensor) -> bool:
        """Check if sequence contains a tool call.

        Abstract hook.  Previously this stub fell through (`pass`) and
        returned None, so `generate_with_tools` silently never executed a
        tool; raising makes the missing implementation explicit.
        """
        raise NotImplementedError(
            "_contains_tool_call must be implemented by a subclass"
        )

    def _extract_tool_call(self, token_ids: torch.Tensor) -> dict:
        """Extract tool call JSON from sequence.  Abstract hook."""
        raise NotImplementedError(
            "_extract_tool_call must be implemented by a subclass"
        )

    def _format_tool_result(self, result: str) -> torch.Tensor:
        """Format tool result as tokens.  Abstract hook."""
        raise NotImplementedError(
            "_format_tool_result must be implemented by a subclass"
        )
|
|
|
|
| |
from transformers import AutoConfig, AutoModelForCausalLM

# Register with the Auto* factories so that model_type "shivik_code"
# (set on ShivikCodeConfig above) resolves to these classes when loading
# via AutoConfig / AutoModelForCausalLM.from_pretrained.
AutoConfig.register("shivik_code", ShivikCodeConfig)
AutoModelForCausalLM.register(ShivikCodeConfig, ShivikCodeForCausalLM)
|
|