import os
import torch
import torch.nn.functional as F
from collections import OrderedDict
import string
import sys
from model import ChatGCLM, MAX_SEQ_LEN


# Ids 0-2 are reserved for special tokens (EOS_ID among them); printable
# characters are mapped to ids starting at OFFSET.
EOS_ID = 2
OFFSET = 3
CHARS = string.printable
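
# Id layout at a glance (string.printable begins with the digits, then the
# lowercase letters, so these particular values are fixed by the standard
# library):
#
#   >>> CHARS.index("0") + OFFSET
#   3
#   >>> CHARS.index("a") + OFFSET
#   13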


def get_model_path():
    """Finds the first model file starting with Turing_ in the current directory."""
    for f in os.listdir("."):
        if f.startswith("Turing_") and f.endswith(".pt"):
            return f
    return None


MODEL_PATH = get_model_path()

if MODEL_PATH is None:
    print("Error: No model checkpoint found!")
    print("Please train the model first with: python3 train.py")
    sys.exit(1)

def encode(text):
    """Maps a string to token ids, dropping characters outside CHARS."""
    return [CHARS.index(c) + OFFSET for c in text if c in CHARS]


def decode(ids):
    """Maps token ids back to text, skipping the reserved special ids."""
    return "".join([CHARS[i - OFFSET] for i in ids if i >= OFFSET])

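# Round-trip behaviour (illustration only, not part of the script):
#
#   >>> decode(encode("hello"))
#   'hello'
#   >>> decode(encode("héllo"))   # non-printable characters are dropped
#   'hllo'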


def load_model(device):
    vocab_size = len(CHARS) + OFFSET

    model = ChatGCLM(vocab_size).to(device)
    if os.path.exists(MODEL_PATH) and os.path.getsize(MODEL_PATH) > 0:
        print(f"Loading model from: {MODEL_PATH}")
        ckpt = torch.load(MODEL_PATH, map_location=device)

        # A checkpoint may be a raw state_dict or a dict wrapping one under
        # a 'model_state_dict' or 'state_dict' key.
        if isinstance(ckpt, dict):
            if 'model_state_dict' in ckpt:
                state_dict = ckpt['model_state_dict']
            elif 'state_dict' in ckpt:
                state_dict = ckpt['state_dict']
            else:
                state_dict = ckpt
        else:
            state_dict = ckpt

        def _strip_module_prefix(sd):
            """Removes the 'module.' prefix that nn.DataParallel adds to keys."""
            keys = list(sd.keys())
            if any(k.startswith('module.') for k in keys):
                new_sd = OrderedDict()
                for k, v in sd.items():
                    new_key = k[len('module.'):] if k.startswith('module.') else k
                    new_sd[new_key] = v
                return new_sd
            return sd

        state_dict = _strip_module_prefix(state_dict)

        res = model.load_state_dict(state_dict, strict=False)
        missing = getattr(res, 'missing_keys', None)
        unexpected = getattr(res, 'unexpected_keys', None)
        if missing:
            print(f"Warning: missing keys when loading state_dict: {missing}")
        if unexpected:
            print(f"Warning: unexpected keys in state_dict: {unexpected}")

        model.eval()
        return model
    else:
        print(f"Error: Could not load model from {MODEL_PATH}")
        return None

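# For reference, either checkpoint shape below would load here; the actual
# save format lives in train.py, and the filename is purely illustrative:
#
#   torch.save(model.state_dict(), "Turing_demo.pt")
#   torch.save({"model_state_dict": model.state_dict()}, "Turing_demo.pt")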

@torch.no_grad()
def generate_stream(model, prompt, device, max_new_tokens=500, temperature=0.7, top_k=50):
    """
    Generates text from the model and streams it to stdout.
    Returns the full generated text.
    """
    model.eval()
    input_ids = encode(prompt)
    x = torch.tensor([input_ids], dtype=torch.long, device=device)

    generated_ids = []

    for _ in range(max_new_tokens):
        # Crop the context to the model's maximum sequence length.
        ctx = x[:, -MAX_SEQ_LEN:] if x.size(1) > MAX_SEQ_LEN else x

        logits = model(ctx)
        next_token_logits = logits[:, -1, :] / temperature

        # Top-k filtering: mask everything below the k-th largest logit.
        if top_k is not None:
            v, _ = torch.topk(next_token_logits, min(top_k, next_token_logits.size(-1)))
            next_token_logits[next_token_logits < v[:, [-1]]] = -float('Inf')

        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = next_token.item()

        if idx == EOS_ID:
            break

        x = torch.cat((x, next_token), dim=1)
        generated_ids.append(idx)

        token_text = decode([idx])
        print(token_text, end="", flush=True)

        # If the model starts the next user turn, erase the "<u>" that was
        # just printed (backspace, overwrite with spaces, backspace) and stop.
        if len(generated_ids) >= 3 and decode(generated_ids[-3:]) == "<u>":
            print('\b\b\b   \b\b\b', end="", flush=True)
            break

    return decode(generated_ids)

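# A tiny worked example of the top-k masking used above (illustration only,
# not executed by this script):
#
#   >>> t = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
#   >>> v, _ = torch.topk(t, 2)          # v == tensor([[3., 2.]])
#   >>> t[t < v[:, [-1]]] = -float('Inf')
#   >>> t
#   tensor([[-inf, 3., 2., -inf]])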

def main():
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
    print(f"Using device: {device}")

    model = load_model(device)

    if model is None:
        sys.exit(1)

    print("\n" + "="*50)
    print("Turing | Chat Interface")
    print(f"Model: {MODEL_PATH}")
    print("Type 'quit', 'exit', or 'q' to end the session.")
    print("="*50 + "\n")

    history = ""
    while True:
        try:
            user_input = input("\n\033[1;36mYou:\033[0m ")

            if user_input.strip().lower() in ['quit', 'exit', 'q']:
                print("\nGoodbye!")
                break

            if not user_input.strip():
                continue

            print("\033[1;32mModel:\033[0m ", end="", flush=True)

            # Each turn is framed with <u> (user) and <a> (assistant) markers,
            # and the accumulated history is replayed as the prompt.
            current_turn = f"<u> {user_input} <a>"
            full_prompt = history + current_turn

            response = generate_stream(model, full_prompt, device=device)

            # Drop a trailing "<u>" left over when generation stopped at the
            # start of the next user turn.
            cleaned_response = response
            if cleaned_response.endswith("<u>"):
                cleaned_response = cleaned_response[:-3]

            history += current_turn + cleaned_response
            print()

        except KeyboardInterrupt:
            print("\n\nExiting...")
            break
        except Exception as e:
            print(f"\nError: {e}")


if __name__ == "__main__":
    main()