annanurov commited on
Commit
2d007ea
·
verified ·
1 Parent(s): 5b72b63

Using model meta-llama/Llama-3.3-70B-Instruct. Adapted parsing of the response.

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -11
src/streamlit_app.py CHANGED
@@ -84,16 +84,21 @@ def extract_cities_with_llm(user_prompt: str) -> list[str]:
84
  """Use HF Inference API to extract city names from a natural-language prompt."""
85
  client = InferenceClient(
86
  # model="mistralai/Mistral-7B-Instruct-v0.3",
87
- model="mistralai/Mistral-7B-Instruct-v0.2",
88
  # model="nvidia/Gemma-4-26B-A4B-NVFP4",
89
  # model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
90
-
91
 
92
  # token=st.secrets.get("HF_TOKEN", None),
93
  # token=st.secrets.get("rainytrek010526001read", None),
94
  # token=st.secrets.get("API_KEY", None),
95
  # token=os.getenv("API_KEY")
96
  token=os.getenv("rainytrek010526001read")
 
 
 
 
 
97
 
98
 
99
  )
@@ -111,19 +116,21 @@ def extract_cities_with_llm(user_prompt: str) -> list[str]:
111
  ]
112
 
113
  response = client.chat_completion(messages=messages, max_tokens=256, temperature=0.1)
114
-
115
-
116
 
117
  raw = response.choices[0].message.content.strip()
 
 
 
 
 
 
118
 
119
- # strip markdown fences if present
120
- if raw.startswith("```"):
121
- raw = raw.split("```")[1]
122
- if raw.startswith("json"):
123
- raw = raw[4:]
124
- raw = raw.strip()
125
 
 
 
126
  cities = json.loads(raw)
 
127
  return [c.strip() for c in cities if isinstance(c, str) and c.strip()]
128
 
129
 
@@ -235,7 +242,7 @@ if run and user_input.strip():
235
  cities = []
236
 
237
  if not cities:
238
- st.markdown('<div class="error-box">No cities found in your prompt. Try mentioning specific city names.</div>', unsafe_allow_html=True)
239
  else:
240
  st.markdown(f'<div class="llm-box"><div class="llm-label">Cities detected by LLM</div>{" · ".join(cities)}</div>', unsafe_allow_html=True)
241
 
 
84
  """Use HF Inference API to extract city names from a natural-language prompt."""
85
  client = InferenceClient(
86
  # model="mistralai/Mistral-7B-Instruct-v0.3",
87
+ # model="mistralai/Mistral-7B-Instruct-v0.2",
88
  # model="nvidia/Gemma-4-26B-A4B-NVFP4",
89
  # model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
90
+ model="meta-llama/Llama-3.3-70B-Instruct",
91
 
92
  # token=st.secrets.get("HF_TOKEN", None),
93
  # token=st.secrets.get("rainytrek010526001read", None),
94
  # token=st.secrets.get("API_KEY", None),
95
  # token=os.getenv("API_KEY")
96
  token=os.getenv("rainytrek010526001read")
97
+
98
+
99
+
100
+
101
+ provider="fireworks-ai"
102
 
103
 
104
  )
 
116
  ]
117
 
118
  response = client.chat_completion(messages=messages, max_tokens=256, temperature=0.1)
 
 
119
 
120
  raw = response.choices[0].message.content.strip()
121
+ # # strip markdown fences if present
122
+ # if raw.startswith("```"):
123
+ # raw = raw.split("```")[1]
124
+ # if raw.startswith("json"):
125
+ # raw = raw[4:]
126
+ # raw = raw.strip()
127
 
128
+ # cities = json.loads(raw)
 
 
 
 
 
129
 
130
+ ## raw = [c.strip() for c in response["choices"][0]["message"]["content"].split("\"") if len(c.strip()) > 1]
131
+
132
  cities = json.loads(raw)
133
+
134
  return [c.strip() for c in cities if isinstance(c, str) and c.strip()]
135
 
136
 
 
242
  cities = []
243
 
244
  if not cities:
245
+ st.markdown('<div class="error-box">what cities in particular should I look for?</div>', unsafe_allow_html=True)
246
  else:
247
  st.markdown(f'<div class="llm-box"><div class="llm-label">Cities detected by LLM</div>{" · ".join(cities)}</div>', unsafe_allow_html=True)
248