Upload inference.py
inference.py CHANGED (+9 -5)
@@ -11,6 +11,7 @@ Or run with a file of intents:
 
 import argparse
 import json
+import os
 import re
 import sys
 
@@ -31,10 +32,16 @@ TOP_P = 0.95
 
 def load_model(adapter_path: str, base_model: str):
     """Load base model + LoRA adapters."""
+    adapter_path = os.path.abspath(adapter_path)
+    if not os.path.isdir(adapter_path):
+        print(f"ERROR: Adapter path not found: {adapter_path}")
+        print("Run train.py first to generate adapters.")
+        sys.exit(1)
+
     print(f"Loading base model: {base_model}")
     model = AutoModelForCausalLM.from_pretrained(
         base_model,
-
+        dtype=torch.float16,
         device_map="auto",
         trust_remote_code=True,
     )
@@ -92,7 +99,6 @@ def generate_config(model, tokenizer, intent_text: str) -> str:
     generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     # Extract only the assistant's response (after the prompt)
-    # For chat-templated output, we need to strip the input prompt
     response = generated[len(prompt):].strip()
 
     # Try to extract JSON if wrapped in markdown
@@ -106,16 +112,14 @@ def generate_config(model, tokenizer, intent_text: str) -> str:
 def validate_json(text: str) -> tuple[bool, dict | None]:
     """Try to parse response as JSON. Returns (success, parsed)."""
     try:
-        # Remove any trailing non-JSON text
         text = text.strip()
-        # Find first { and last }
         start = text.find("{")
         end = text.rfind("}")
         if start != -1 and end != -1 and end > start:
             text = text[start:end + 1]
         parsed = json.loads(text)
         return True, parsed
-    except json.JSONDecodeError
+    except json.JSONDecodeError:
         return False, None
 
 
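The second hunk shows only the start of load_model. In a typical LoRA inference script the function would go on to attach the saved adapters and load the tokenizer; a minimal sketch of that pattern, assuming the peft library (the continuation is not shown in this diff):

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_model(adapter_path: str, base_model: str):
    """Load base model + LoRA adapters (sketch, not the repo's code)."""
    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        dtype=torch.float16,  # newer transformers; older releases spell this torch_dtype
        device_map="auto",
        trust_remote_code=True,
    )
    model = PeftModel.from_pretrained(model, adapter_path)  # attach LoRA weights from train.py
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    return model, tokenizer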
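One caveat on the response = generated[len(prompt):] slice kept in the third hunk: it assumes the decoded text begins with the prompt verbatim, which chat templates and skip_special_tokens can break. Slicing in token space avoids that drift; a sketch, with inputs and outputs standing in for the script's own variables around model.generate():

# Sketch, not the repo's code: drop the prompt in token space,
# then decode only the newly generated tokens.
prompt_len = inputs["input_ids"].shape[1]   # number of prompt tokens fed to generate()
new_tokens = outputs[0][prompt_len:]        # generated continuation only
response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()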
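The last hunk is the substantive bug fix: except json.JSONDecodeError without the trailing colon is a SyntaxError, so the old file would not even import. With the colon restored, the brace scan in validate_json tolerates chatter around the JSON object. Illustrative calls (made-up values, not from the repo):

ok, cfg = validate_json('Here is the config: {"vlan": 120, "name": "iot"} Hope that helps!')
# ok == True, cfg == {"vlan": 120, "name": "iot"}

ok, cfg = validate_json("no json here")
# ok == False, cfg is None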