| """ |
| Inference script for testing the fine-tuned telecom intent model. |
| Loads LoRA adapters and generates network configurations from natural language intents. |
| |
| Usage on Kaggle: |
| python inference.py --intent "Deploy a low latency slice for autonomous drones in the harbor zone" |
| |
| Or run with a file of intents: |
| python inference.py --input_file intents.txt --output_file configs.json |
| """ |
|
|
| import argparse |
| import json |
| import os |
| import re |
| import sys |
|
|
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| from peft import PeftModel |
|
|
| |
| |
| |
|
|
# Hugging Face model ID of the frozen base model the LoRA adapters were trained on.
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct"
# Directory produced by train.py containing the LoRA adapter weights.
ADAPTER_PATH = "./qwen2.5-7b-telecom-intent-lora"
# Generation settings: hard cap on output length; low temperature keeps the
# sampled JSON near-deterministic even though do_sample=True is used below.
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.1
TOP_P = 0.95
|
|
|
|
def load_model(adapter_path: str, base_model: str):
    """Load the base causal LM and attach the fine-tuned LoRA adapters.

    Args:
        adapter_path: Directory containing the trained LoRA adapter weights.
        base_model: Hugging Face model ID (or local path) of the base model.

    Returns:
        A ``(model, tokenizer)`` pair ready for generation (model in eval mode).

    Exits the process with status 1 when the adapter directory is missing.
    """
    resolved_path = os.path.abspath(adapter_path)
    if not os.path.isdir(resolved_path):
        print(f"ERROR: Adapter path not found: {resolved_path}")
        print("Run train.py first to generate adapters.")
        sys.exit(1)

    print(f"Loading base model: {base_model}")
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

    print(f"Loading LoRA adapters: {resolved_path}")
    peft_model = PeftModel.from_pretrained(base, resolved_path)
    peft_model.eval()

    tokenizer = AutoTokenizer.from_pretrained(
        base_model,
        trust_remote_code=True,
        padding_side="left",
    )
    # Some tokenizers ship without a pad token; reuse EOS so padded batches work.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print("Model ready!")
    return peft_model, tokenizer
|
|
|
|
def generate_config(model, tokenizer, intent_text: str) -> str:
    """Generate a network configuration from a natural language intent.

    Builds a chat prompt from a fixed orchestrator system message plus the
    user intent, samples a completion, and returns only the newly generated
    text with any surrounding Markdown code fences stripped.

    Args:
        model: Causal LM (base + LoRA) in eval mode.
        tokenizer: Matching tokenizer with a chat template.
        intent_text: Natural language network intent.

    Returns:
        The model's response text (ideally a JSON configuration string).
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are a 5G/6G network orchestrator. "
                "Given a natural language network intent, output a valid, "
                "spec-compliant JSON network configuration. "
                "Do not include any explanation — only the JSON configuration."
            ),
        },
        {"role": "user", "content": intent_text},
    ]

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # BUG FIX: decode only the newly generated tokens. The previous code
    # decoded the full sequence with skip_special_tokens=True and then sliced
    # by len(prompt); because the decoded text no longer contains the chat
    # template's special tokens, it is shorter than the raw prompt string, so
    # that slice cut off the beginning of the model's actual response.
    prompt_token_count = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_token_count:],
        skip_special_tokens=True,
    ).strip()

    # Unwrap a ```json ... ``` (or bare ```) fence if the model emitted one.
    json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", response, re.DOTALL)
    if json_match:
        response = json_match.group(1)

    return response
|
|
|
|
| def validate_json(text: str) -> tuple[bool, dict | None]: |
| """Try to parse response as JSON. Returns (success, parsed).""" |
| try: |
| text = text.strip() |
| start = text.find("{") |
| end = text.rfind("}") |
| if start != -1 and end != -1 and end > start: |
| text = text[start:end + 1] |
| parsed = json.loads(text) |
| return True, parsed |
| except json.JSONDecodeError: |
| return False, None |
|
|
|
|
def _run_interactive(model, tokenizer) -> None:
    """Interactive REPL: read intents from stdin until the user quits."""
    print("\nInteractive mode. Type 'quit' to exit.")
    while True:
        try:
            user_input = input("\nIntent> ")
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C should end the session cleanly, not traceback.
            break
        if user_input.lower() in ("quit", "exit", "q"):
            break
        config = generate_config(model, tokenizer, user_input)
        is_valid, parsed = validate_json(config)
        print(f"\n{'=' * 60}")
        print(f"Generated Config (valid={is_valid}):")
        print(f"{'=' * 60}")
        if is_valid:
            print(json.dumps(parsed, indent=2))
        else:
            print(config)


def _run_batch(model, tokenizer, intents, output_file: str) -> None:
    """Generate one config per intent and write all results to a JSON file."""
    results = []
    valid_count = 0
    for i, intent in enumerate(intents):
        print(f"\n[{i + 1}/{len(intents)}] Processing: {intent[:80]}...")
        config = generate_config(model, tokenizer, intent)
        is_valid, parsed = validate_json(config)
        if is_valid:
            valid_count += 1
        results.append({
            "intent": intent,
            # Keep the raw text when parsing failed so failures are debuggable.
            "generated_config": parsed if is_valid else config,
            "json_valid": is_valid,
        })

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)

    print(f"\n{'=' * 60}")
    print(f"Batch complete: {valid_count}/{len(intents)} valid JSON configs")
    print(f"Results saved to: {output_file}")


def main():
    """Parse CLI args, load the model, and run interactive or batch inference."""
    parser = argparse.ArgumentParser(description="Telecom Intent Inference")
    parser.add_argument(
        "--intent",
        type=str,
        default=None,
        help="Single natural language intent string",
    )
    parser.add_argument(
        "--input_file",
        type=str,
        default=None,
        help="File with one intent per line",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        default="generated_configs.json",
        help="Output JSON file for batch results",
    )
    parser.add_argument(
        "--adapter_path",
        type=str,
        default=ADAPTER_PATH,
        help="Path to LoRA adapters",
    )
    parser.add_argument(
        "--base_model",
        type=str,
        default=BASE_MODEL,
        help="Base model name",
    )
    args = parser.parse_args()

    model, tokenizer = load_model(args.adapter_path, args.base_model)

    if args.intent:
        intents = [args.intent]
    elif args.input_file:
        with open(args.input_file, "r", encoding="utf-8") as f:
            intents = [line.strip() for line in f if line.strip()]
    else:
        # No --intent and no --input_file: drop into the REPL and exit after.
        _run_interactive(model, tokenizer)
        return

    _run_batch(model, tokenizer, intents, args.output_file)
|
|
|
|
# Script entry point; allows importing this module without side effects.
if __name__ == "__main__":
    main()
|
|