"""
AgGPT-21 Interactive Chat Interface
A command-line conversational interface for the trained AgGPT-21 model.
"""
|
|
import os
import sys
import unicodedata

import torch

from AgGPT21 import WordRNN, generate_text, MODEL_FILE, DEVICE
|
|
def load_model():
    """Load the trained AgGPT-21 checkpoint from ``MODEL_FILE``.

    The checkpoint is indexed with the keys ``"stoi"``, ``"itos"`` and
    ``"model_state"``; ``len(stoi)`` is used as the vocabulary size when
    building the ``WordRNN``. A short summary is printed on success.

    Returns:
        A ``(model, stoi, itos)`` tuple. ``model`` has been moved to
        ``DEVICE`` and switched to eval mode.

    Exits the process with status 1 if the file is missing or loading fails.
    """
    if not os.path.exists(MODEL_FILE):
        print(f"❌ Model file '{MODEL_FILE}' not found!")
        print("Please train the model first by running: python AgGPT21.py")
        sys.exit(1)

    try:
        print("🔄 Loading AgGPT-21 model...")
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints.
        ckpt = torch.load(MODEL_FILE, map_location=DEVICE)
        stoi = ckpt["stoi"]
        itos = ckpt["itos"]
        model = WordRNN(len(stoi))
        model.load_state_dict(ckpt["model_state"])
        # load_state_dict copies values into the module's existing parameters
        # without moving them, and nothing else here places the module on
        # DEVICE. Move it explicitly so it matches the device passed to
        # generate_text().
        model.to(DEVICE)
        model.eval()
    except Exception as e:
        print(f"❌ Error loading model: {e}")
        sys.exit(1)

    param_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("✅ Model loaded successfully!")
    print(f"   • Parameters: {param_count:,}")
    print(f"   • Vocabulary size: {len(stoi):,}")
    print(f"   • Device: {DEVICE}")
    print()

    return model, stoi, itos
|
|
| def print_banner(): |
| """Display the AgGPT-21 banner.""" |
| banner = """ |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| β π€ AgGPT-21 π€ β |
| β Interactive Chat Interface β |
| β β |
| β β’ Type your message and press Enter to chat β |
| β β’ Use 'quit', 'exit', or 'bye' to end the conversation β |
| β β’ Use 'help' for more options β |
| ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ |
| """ |
| print(banner) |
|
|
def print_help():
    """Print the chat commands and the tunable generation settings.

    The settings section lists each setting's accepted range and default.
    The values currently in effect are shown by the ``model`` chat command.
    """
    help_lines = [
        "",
        "🔧 AgGPT-21 Chat Commands:",
        "   • Just type your message to chat with the AI",
        "   • 'quit', 'exit', 'bye' - End the conversation",
        "   • 'help' - Show this help message",
        "   • 'clear' - Clear the screen",
        "   • 'model' - Show model information",
        "   • 'temp X' - Set temperature (e.g., 'temp 0.8')",
        "   • 'length X' - Set response length (e.g., 'length 150')",
        "",
        "🎚️ Settings:",
        "   • Temperature: Controls creativity (0.1-2.0, default: 0.9)",
        "   • Length: Number of words to generate (50-500, default: 200)",
        "",
    ]
    print("\n".join(help_lines))
|
|
def _strip_prompt_echo(response, prompt):
    """Return the part of *response* that follows the echoed *prompt*.

    Drops the first ``len(prompt.split())`` whitespace-separated words of
    *response* and re-joins the rest with single spaces. This assumes
    generate_text() returns the prompt followed by the continuation, with
    words aligned to ``str.split`` (TODO confirm against
    AgGPT21.generate_text). Returns ``""`` when nothing follows the prompt,
    instead of echoing the user's own words back as the reply.
    """
    response_words = response.split()
    return " ".join(response_words[len(prompt.split()):])


def main():
    """Run the interactive AgGPT-21 chat loop until the user leaves.

    Each non-empty input line is one of:
    - a built-in command: ``quit``/``exit``/``bye``, ``help``, ``clear``,
      ``model``, ``temp X`` or ``length X``, or
    - a prompt, which is passed to generate_text().

    Ctrl-C, or end of input (Ctrl-D / closed stdin), ends the session.
    """
    print_banner()

    model, stoi, itos = load_model()

    # Sampling settings; temperature and length are adjustable at runtime.
    temperature = 0.9
    length = 200
    top_k = 50
    top_p = 0.9

    print("💬 Chat started! Type your message below:")
    print("=" * 70)

    while True:
        try:
            user_input = input("\n👤 You: ").strip()
            if not user_input:
                continue

            user_lower = user_input.lower()
            # "temp 0.8" -> ("temp", " ", "0.8"); a bare "temp" yields "".
            command, _, argument = user_lower.partition(" ")

            if user_lower in ("quit", "exit", "bye"):
                print("\n👋 Goodbye! Thanks for chatting with AgGPT-21!")
                break

            if user_lower == "help":
                print_help()
                continue

            if user_lower == "clear":
                os.system("clear" if os.name == "posix" else "cls")
                print_banner()
                continue

            if user_lower == "model":
                param_count = sum(
                    p.numel() for p in model.parameters() if p.requires_grad
                )
                print("\n🤖 Model Information:")
                print(f"   • Parameters: {param_count:,}")
                print(f"   • Vocabulary: {len(stoi):,} words")
                print(f"   • Device: {DEVICE}")
                print(f"   • Temperature: {temperature}")
                print(f"   • Max length: {length}")
                continue

            if command == "temp":
                try:
                    new_temp = float(argument.split()[0])
                except (IndexError, ValueError):
                    print("❌ Invalid temperature. Use: temp 0.8")
                    continue
                # NaN fails both comparisons, so it is rejected as well.
                if 0.1 <= new_temp <= 2.0:
                    temperature = new_temp
                    print(f"🌡️ Temperature set to {temperature}")
                else:
                    print("❌ Temperature must be between 0.1 and 2.0")
                continue

            if command == "length":
                try:
                    new_length = int(argument.split()[0])
                except (IndexError, ValueError):
                    print("❌ Invalid length. Use: length 150")
                    continue
                if 50 <= new_length <= 500:
                    length = new_length
                    print(f"📏 Response length set to {length} words")
                else:
                    print("❌ Length must be between 50 and 500")
                continue

            status = "🤖 AgGPT-21 (thinking...)"
            print(f"\n{status}", end="", flush=True)
            # A bare "\r" would leave the tail of the status visible when the
            # reply is shorter, so blank the line first. The +1 covers the
            # double-width emoji.
            erase = "\r" + " " * (len(status) + 1) + "\r"

            try:
                response = generate_text(
                    model=model,
                    stoi=stoi,
                    itos=itos,
                    prompt=user_input,
                    length=length,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    device=DEVICE,
                )
            except Exception as e:
                print(f"{erase}❌ Error generating response: {e}")
                continue

            ai_response = _strip_prompt_echo(response, user_input)
            print(f"{erase}🤖 AgGPT-21: {ai_response or '(no response)'}")

        except (KeyboardInterrupt, EOFError):
            # input() raises EOFError at end of stdin (Ctrl-D or a closed
            # pipe). If the generic handler below caught it, the next input()
            # would raise again and the loop would never end.
            print("\n\n👋 Chat interrupted. Goodbye!")
            break
        except Exception as e:
            print(f"\n❌ Unexpected error: {e}")
|
|
if __name__ == "__main__":
    # Start the interactive chat only when this file is executed as a
    # script, not when it is imported.
    main()
|
|