| import torch |
| import torch.nn as nn |
| import unicodedata |
| import os |
| import gradio as gr |
| from transformers import PreTrainedTokenizerFast, PretrainedConfig, PreTrainedModel |
| from tokenizers import decoders |
|
|
| |
class IsaiConfig(PretrainedConfig):
    """Configuration for the isai causal LM.

    Defaults describe a small LLaMA-style layout (24 layers, hidden 1024,
    SwiGLU intermediate 2816). All constructor arguments are stored on the
    instance so they serialize via ``save_pretrained`` and round-trip
    through ``from_pretrained``.
    """

    model_type = "isai"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=1024,
        intermediate_size=2816,
        num_hidden_layers=24,
        num_attention_heads=16,
        num_key_value_heads=16,
        hidden_act="silu",
        max_position_embeddings=2048,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        # BUG FIX: these three were accepted but never assigned, so they were
        # silently dropped and lost on save_pretrained/from_pretrained.
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
|
|
class IsaiRMSNorm(nn.Module):
    """Root-mean-square layer norm (no mean-centering), T5/LLaMA style.

    Normalizes the last dimension by its RMS and applies a learnable
    per-channel scale. The reduction runs in float32 for numerical
    stability regardless of the input dtype.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        # Learnable scale, initialized to the identity transform.
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x = hidden_states.float()
        # 1 / RMS over the channel dimension, eps-stabilized.
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.variance_epsilon).rsqrt()
        return self.weight * (x * inv_rms).to(orig_dtype)
|
|
class IsaiForCausalLM(PreTrainedModel):
    """Minimal LLaMA-shaped causal LM skeleton.

    NOTE(review): ``self_attn`` here is a single position-wise ``nn.Linear``
    — no information mixes across token positions — and ``forward`` applies
    no positional encoding or causal mask. Presumably a simplified/stub
    architecture matching a saved checkpoint; confirm against training code.

    The nested ModuleDict/ModuleList key layout ("model.embed_tokens",
    "model.layers.<i>.mlp.gate_proj", "model.norm", "lm_head") must stay
    exactly as-is so ``load_state_dict`` matches the checkpoint files.
    """
    config_class = IsaiConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = nn.ModuleDict({
            # Token embedding: (batch, seq) ids -> (batch, seq, hidden).
            "embed_tokens": nn.Embedding(config.vocab_size, config.hidden_size),
            "layers": nn.ModuleList([nn.ModuleDict({
                # Pre-norm before the "attention" projection.
                "input_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
                # Pre-norm before the MLP.
                "post_attention_layernorm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps),
                # Position-wise linear standing in for self-attention.
                "self_attn": nn.Linear(config.hidden_size, config.hidden_size, bias=False),
                # SwiGLU MLP: down(silu(gate(x)) * up(x)).
                "mlp": nn.ModuleDict({
                    "gate_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                    "up_proj": nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                    "down_proj": nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
                })
            }) for _ in range(config.num_hidden_layers)]),
            # Final RMSNorm before the LM head.
            "norm": IsaiRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        })
        # Untied output projection to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()


    def forward(self, input_ids=None, **kwargs):
        """Return raw logits of shape (batch, seq, vocab_size).

        Extra kwargs (attention_mask, labels, ...) are accepted for
        interface compatibility but ignored; no loss is computed.
        """
        hidden_states = self.model.embed_tokens(input_ids)
        # Pre-norm residual blocks: x + attn(norm(x)), then x + mlp(norm(x)).
        for layer in self.model.layers:
            h = layer.input_layernorm(hidden_states)
            hidden_states = hidden_states + layer.self_attn(h)
            h = layer.post_attention_layernorm(hidden_states)
            hidden_states = hidden_states + layer.mlp.down_proj(nn.functional.silu(layer.mlp.gate_proj(h)) * layer.mlp.up_proj(h))
        logits = self.lm_head(self.model.norm(hidden_states))
        return logits
|
|
| |
# --- Model/tokenizer bootstrap (runs at import time) ---
model_dir = "models/isai-v4.2"
device = "cuda" if torch.cuda.is_available() else "cpu"


tokenizer = PreTrainedTokenizerFast.from_pretrained(model_dir)
# Force byte-level decoding on the backing fast tokenizer (private attr).
# NOTE(review): presumably the saved tokenizer ships without the right
# decoder, so decoded text would otherwise be garbled — verify against
# the tokenizer's training config.
tokenizer._tokenizer.decoder = decoders.ByteLevel()


config = IsaiConfig.from_pretrained(model_dir)
model = IsaiForCausalLM(config).to(device)


# Prefer the safetensors checkpoint; fall back to the legacy pickle one.
weights_path = os.path.join(model_dir, "model.safetensors")
if os.path.exists(weights_path):
    from safetensors.torch import load_file
    model.load_state_dict(load_file(weights_path))
else:
    # NOTE(review): torch.load unpickles arbitrary objects — only safe for
    # trusted checkpoints; consider weights_only=True (torch >= 1.13).
    model.load_state_dict(torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location=device))
model.eval()  # inference mode: disables dropout etc.
|
|
| |
def predict(message, history):
    """Greedy-decode a single reply to ``message``.

    The model works on decomposed jamo, so the input is NFD-normalized
    before encoding and the output NFC-recomposed after decoding.
    ``history`` is accepted to satisfy the gr.ChatInterface callback
    contract but is not used: every turn is answered independently.

    Returns the generated reply as a precomposed-Hangul string.
    """
    # Decompose Hangul syllables into jamo — the unit the tokenizer expects.
    decomposed_input = unicodedata.normalize('NFD', message)
    input_ids = tokenizer.encode(decomposed_input, return_tensors="pt").to(device)

    # Guard: an empty encoding would make logits[:, -1, :] fail below.
    if input_ids.shape[1] == 0:
        return ""

    current_ids = input_ids
    max_new_tokens = 50

    # One no_grad scope around the whole loop instead of re-entering the
    # context every step. Greedy argmax decoding; the full sequence is
    # re-forwarded each step (no KV cache in this model).
    with torch.no_grad():
        for _ in range(max_new_tokens):
            logits = model(current_ids)
            next_token = torch.argmax(logits[:, -1, :], dim=-1).unsqueeze(0)
            current_ids = torch.cat([current_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    # Keep only the newly generated tokens; a trailing EOS is removed by
    # skip_special_tokens.
    generated_tokens = current_ids[0][input_ids.shape[1]:]
    raw_response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    # Recompose jamo back into precomposed Hangul syllables for display.
    final_response = unicodedata.normalize('NFC', raw_response)

    return final_response
|
|
| |
# Gradio chat UI.
# BUG FIX: the original description/examples strings were mojibake —
# UTF-8 Korean mis-decoded through a legacy single-byte codepage, with
# 0x85 continuation bytes rendered as line breaks (which also split the
# string literal across lines, a syntax error). Reconstructed Hangul:
demo = gr.ChatInterface(
    fn=predict,
    title="isai-v4.2 Jaso-Level Chat",
    description="자소 단위(NFD)로 소통하는 이상한 일상 대화 모델입니다. 입력은 자동으로 분해되고 출력은 다시 한글로 조합됩니다.",
    examples=["안녕? 반가워.", "오늘 날씨가 어때?", "너의 이름은 뭐야?"],
)


if __name__ == "__main__":
    # share=True exposes a temporary public gradio.live URL.
    demo.launch(share=True)