nraptisss committed (verified)
Commit 015111a · Parent(s): 512e317

Upload inference.py

Files changed (1):
  1. inference.py (+9 -5)
inference.py CHANGED
@@ -11,6 +11,7 @@ Or run with a file of intents:
 
 import argparse
 import json
+import os
 import re
 import sys
 
@@ -31,10 +32,16 @@ TOP_P = 0.95
 
 def load_model(adapter_path: str, base_model: str):
     """Load base model + LoRA adapters."""
+    adapter_path = os.path.abspath(adapter_path)
+    if not os.path.isdir(adapter_path):
+        print(f"ERROR: Adapter path not found: {adapter_path}")
+        print("Run train.py first to generate adapters.")
+        sys.exit(1)
+
     print(f"Loading base model: {base_model}")
     model = AutoModelForCausalLM.from_pretrained(
         base_model,
-        torch_dtype=torch.float16,
+        dtype=torch.float16,
         device_map="auto",
         trust_remote_code=True,
     )
@@ -92,7 +99,6 @@ def generate_config(model, tokenizer, intent_text: str) -> str:
     generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     # Extract only the assistant's response (after the prompt)
-    # For chat-templated output, we need to strip the input prompt
     response = generated[len(prompt):].strip()
 
     # Try to extract JSON if wrapped in markdown
@@ -106,16 +112,14 @@ def generate_config(model, tokenizer, intent_text: str) -> str:
 def validate_json(text: str) -> tuple[bool, dict | None]:
     """Try to parse response as JSON. Returns (success, parsed)."""
     try:
-        # Remove any trailing non-JSON text
         text = text.strip()
-        # Find first { and last }
         start = text.find("{")
         end = text.rfind("}")
         if start != -1 and end != -1 and end > start:
             text = text[start:end + 1]
         parsed = json.loads(text)
         return True, parsed
-    except json.JSONDecodeError as e:
+    except json.JSONDecodeError:
         return False, None
 
 
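Two of these changes are worth a note. The `torch_dtype=` to `dtype=` edit tracks recent transformers releases, where the `torch_dtype` argument to `from_pretrained` is deprecated in favor of `dtype`. The sketch below shows the shape of the updated load path end to end; it is a minimal reconstruction under stated assumptions, not the file's full code: in particular, the PeftModel.from_pretrained call that attaches the LoRA adapters sits outside this diff and is assumed here.

# Minimal sketch of the updated load path. Assumptions: a transformers version
# that accepts the `dtype` kwarg, peft installed; the adapter-attach step is
# not shown in the diff and is assumed to use PeftModel.from_pretrained.
import os
import sys

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model(adapter_path: str, base_model: str):
    """Load base model + LoRA adapters, failing fast on a bad adapter path."""
    adapter_path = os.path.abspath(adapter_path)
    if not os.path.isdir(adapter_path):
        # Clear, actionable error instead of a deep traceback from peft
        print(f"ERROR: Adapter path not found: {adapter_path}")
        print("Run train.py first to generate adapters.")
        sys.exit(1)

    model = AutoModelForCausalLM.from_pretrained(
        base_model,
        dtype=torch.float16,  # `torch_dtype=` is deprecated in newer transformers
        device_map="auto",
        trust_remote_code=True,
    )
    # Assumed adapter-attach step; the actual code for this is outside the diff.
    model = PeftModel.from_pretrained(model, adapter_path)
    tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
    return model, tokenizer

The validate_json change only drops the unused `as e` binding; behavior is unchanged. A reply such as 'Sure! Here is the config: {"vlan": 10}' still parses, since the function slices from the first "{" to the last "}" before calling json.loads.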