File size: 6,405 Bytes
bd023d4 015111a bd023d4 015111a bd023d4 015111a bd023d4 015111a bd023d4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 | """
Inference script for testing the fine-tuned telecom intent model.
Loads LoRA adapters and generates network configurations from natural language intents.
Usage on Kaggle:
python inference.py --intent "Deploy a low latency slice for autonomous drones in the harbor zone"
Or run with a file of intents:
python inference.py --input_file intents.txt --output_file configs.json
"""
import argparse
import json
import os
import re
import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# ============================================================================
# CONFIGURATION
# ============================================================================
BASE_MODEL = "Qwen/Qwen2.5-7B-Instruct" # must match train.py
ADAPTER_PATH = "./qwen2.5-7b-telecom-intent-lora" # output from train.py
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.1 # low for deterministic config generation
TOP_P = 0.95
def load_model(adapter_path: str, base_model: str):
"""Load base model + LoRA adapters."""
adapter_path = os.path.abspath(adapter_path)
if not os.path.isdir(adapter_path):
print(f"ERROR: Adapter path not found: {adapter_path}")
print("Run train.py first to generate adapters.")
sys.exit(1)
print(f"Loading base model: {base_model}")
model = AutoModelForCausalLM.from_pretrained(
base_model,
dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
)
print(f"Loading LoRA adapters: {adapter_path}")
model = PeftModel.from_pretrained(model, adapter_path)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(
base_model,
trust_remote_code=True,
padding_side="left",
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
print("Model ready!")
return model, tokenizer
def generate_config(model, tokenizer, intent_text: str) -> str:
"""Generate a network configuration from a natural language intent."""
messages = [
{
"role": "system",
"content": (
"You are a 5G/6G network orchestrator. "
"Given a natural language network intent, output a valid, "
"spec-compliant JSON network configuration. "
"Do not include any explanation — only the JSON configuration."
),
},
{"role": "user", "content": intent_text},
]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=MAX_NEW_TOKENS,
temperature=TEMPERATURE,
top_p=TOP_P,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract only the assistant's response (after the prompt)
response = generated[len(prompt):].strip()
# Try to extract JSON if wrapped in markdown
json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", response, re.DOTALL)
if json_match:
response = json_match.group(1)
return response
def validate_json(text: str) -> tuple[bool, dict | None]:
"""Try to parse response as JSON. Returns (success, parsed)."""
try:
text = text.strip()
start = text.find("{")
end = text.rfind("}")
if start != -1 and end != -1 and end > start:
text = text[start:end + 1]
parsed = json.loads(text)
return True, parsed
except json.JSONDecodeError:
return False, None
def main():
parser = argparse.ArgumentParser(description="Telecom Intent Inference")
parser.add_argument(
"--intent",
type=str,
default=None,
help="Single natural language intent string",
)
parser.add_argument(
"--input_file",
type=str,
default=None,
help="File with one intent per line",
)
parser.add_argument(
"--output_file",
type=str,
default="generated_configs.json",
help="Output JSON file for batch results",
)
parser.add_argument(
"--adapter_path",
type=str,
default=ADAPTER_PATH,
help="Path to LoRA adapters",
)
parser.add_argument(
"--base_model",
type=str,
default=BASE_MODEL,
help="Base model name",
)
args = parser.parse_args()
model, tokenizer = load_model(args.adapter_path, args.base_model)
intents = []
if args.intent:
intents = [args.intent]
elif args.input_file:
with open(args.input_file, "r") as f:
intents = [line.strip() for line in f if line.strip()]
else:
# Interactive mode
print("\nInteractive mode. Type 'quit' to exit.")
while True:
user_input = input("\nIntent> ")
if user_input.lower() in ("quit", "exit", "q"):
break
config = generate_config(model, tokenizer, user_input)
is_valid, parsed = validate_json(config)
print(f"\n{'=' * 60}")
print(f"Generated Config (valid={is_valid}):")
print(f"{'=' * 60}")
if is_valid:
print(json.dumps(parsed, indent=2))
else:
print(config)
return
# Batch processing
results = []
valid_count = 0
for i, intent in enumerate(intents):
print(f"\n[{i + 1}/{len(intents)}] Processing: {intent[:80]}...")
config = generate_config(model, tokenizer, intent)
is_valid, parsed = validate_json(config)
if is_valid:
valid_count += 1
results.append({
"intent": intent,
"generated_config": parsed if is_valid else config,
"json_valid": is_valid,
})
# Save results
with open(args.output_file, "w") as f:
json.dump(results, f, indent=2)
print(f"\n{'=' * 60}")
print(f"Batch complete: {valid_count}/{len(intents)} valid JSON configs")
print(f"Results saved to: {args.output_file}")
if __name__ == "__main__":
main()
|