"""
Inference script for testing the fine-tuned telecom intent model.
Loads LoRA adapters and generates network configurations from natural language intents.
Usage on Kaggle:
python inference.py --intent "Deploy a low latency slice for autonomous drones in the harbor zone"
Or run with a file of intents:
python inference.py --input_file intents.txt --output_file configs.json
"""
import argparse
import json
import os
import re
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# ============================================================================
# CONFIGURATION
# ============================================================================
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct" # must match train.py
ADAPTER_PATH = "./qwen2.5-7b-telecom-intent-lora" # output from train.py
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.1  # low for near-deterministic config generation
TOP_P = 0.95
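# Note: with do_sample=True a low temperature still samples from a sharpened
# distribution; pass do_sample=False to generate() for fully greedy decoding.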
def load_model(adapter_path: str, base_model: str):
"""Load base model + LoRA adapters."""
adapter_path = os.path.abspath(adapter_path)
if not os.path.isdir(adapter_path):
print(f"ERROR: Adapter path not found: {adapter_path}")
print("Run train.py first to generate adapters.")
sys.exit(1)
print(f"Loading base model: {base_model}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
        torch_dtype=torch.float16,  # half precision; `torch_dtype` is the kwarg honored across transformers versions
device_map="auto",
trust_remote_code=True,
)
print(f"Loading LoRA adapters: {adapter_path}")
model = PeftModel.from_pretrained(model, adapter_path)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(
base_model,
trust_remote_code=True,
padding_side="left",
)
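    # Left padding matters for decoder-only models when prompts are batched;
    # right padding would leave pad tokens between the prompt and generation.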
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("Model ready!")
return model, tokenizer
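# Optional: for slightly faster inference you can fold the adapters into the
# base weights with PEFT's merge_and_unload(), at the cost of hot-swapping:
#   model = model.merge_and_unload()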
def generate_config(model, tokenizer, intent_text: str) -> str:
"""Generate a network configuration from a natural language intent."""
messages = [
{
"role": "system",
"content": (
"You are a 5G/6G network orchestrator. "
"Given a natural language network intent, output a valid, "
"spec-compliant JSON network configuration. "
"Do not include any explanation — only the JSON configuration."
),
},
{"role": "user", "content": intent_text},
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=MAX_NEW_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
    # Decode only the newly generated tokens; slicing the decoded string by
    # len(prompt) breaks once skip_special_tokens strips the chat-template markers.
    prompt_len = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
# Try to extract JSON if wrapped in markdown
json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", response, re.DOTALL)
if json_match:
response = json_match.group(1)
return response
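# Quick sanity check (illustrative intent string):
#   model, tok = load_model(ADAPTER_PATH, BASE_MODEL)
#   print(generate_config(model, tok, "Deploy a low latency slice for drones"))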
def validate_json(text: str) -> tuple[bool, dict | None]:
    """Try to parse the response as a JSON object. Returns (success, parsed)."""
    try:
        text = text.strip()
        # Trim any stray prose around the outermost JSON object
        start = text.find("{")
        end = text.rfind("}")
        if start != -1 and end > start:
            text = text[start:end + 1]
        parsed = json.loads(text)
        # Reject top-level non-objects (e.g. a bare list or string)
        if not isinstance(parsed, dict):
            return False, None
        return True, parsed
    except json.JSONDecodeError:
        return False, None
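# Illustrative example of the brace-trimming fallback (keys are made up):
#   validate_json('Here it is: {"slice_type": "uRLLC", "zone": "harbor"} done')
#   -> (True, {"slice_type": "uRLLC", "zone": "harbor"})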
def main():
parser = argparse.ArgumentParser(description="Telecom Intent Inference")
parser.add_argument(
"--intent",
type=str,
default=None,
help="Single natural language intent string",
)
parser.add_argument(
"--input_file",
type=str,
default=None,
help="File with one intent per line",
)
parser.add_argument(
"--output_file",
type=str,
default="generated_configs.json",
help="Output JSON file for batch results",
)
parser.add_argument(
"--adapter_path",
type=str,
default=ADAPTER_PATH,
help="Path to LoRA adapters",
)
parser.add_argument(
"--base_model",
type=str,
default=BASE_MODEL,
help="Base model name",
)
args = parser.parse_args()
model, tokenizer = load_model(args.adapter_path, args.base_model)
intents = []
if args.intent:
intents = [args.intent]
elif args.input_file:
with open(args.input_file, "r") as f:
intents = [line.strip() for line in f if line.strip()]
else:
# Interactive mode
print("\nInteractive mode. Type 'quit' to exit.")
while True:
user_input = input("\nIntent> ")
if user_input.lower() in ("quit", "exit", "q"):
break
config = generate_config(model, tokenizer, user_input)
is_valid, parsed = validate_json(config)
print(f"\n{'=' * 60}")
print(f"Generated Config (valid={is_valid}):")
print(f"{'=' * 60}")
if is_valid:
print(json.dumps(parsed, indent=2))
else:
print(config)
return
# Batch processing
results = []
valid_count = 0
for i, intent in enumerate(intents):
print(f"\n[{i + 1}/{len(intents)}] Processing: {intent[:80]}...")
config = generate_config(model, tokenizer, intent)
is_valid, parsed = validate_json(config)
if is_valid:
valid_count += 1
results.append({
"intent": intent,
"generated_config": parsed if is_valid else config,
"json_valid": is_valid,
})
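    # Each record: {"intent": str, "generated_config": dict | str, "json_valid": bool}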
# Save results
with open(args.output_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\n{'=' * 60}")
print(f"Batch complete: {valid_count}/{len(intents)} valid JSON configs")
print(f"Results saved to: {args.output_file}")
if __name__ == "__main__":
main()