Using model meta-llama/Llama-3.3-70B-Instruct. Adapted parsing of the response.
Browse files — src/streamlit_app.py (+18 −11)
src/streamlit_app.py
CHANGED
|
@@ -84,16 +84,21 @@ def extract_cities_with_llm(user_prompt: str) -> list[str]:
|
|
| 84 |
"""Use HF Inference API to extract city names from a natural-language prompt."""
|
| 85 |
client = InferenceClient(
|
| 86 |
# model="mistralai/Mistral-7B-Instruct-v0.3",
|
| 87 |
-
model="mistralai/Mistral-7B-Instruct-v0.2",
|
| 88 |
# model="nvidia/Gemma-4-26B-A4B-NVFP4",
|
| 89 |
# model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
| 90 |
-
|
| 91 |
|
| 92 |
# token=st.secrets.get("HF_TOKEN", None),
|
| 93 |
# token=st.secrets.get("rainytrek010526001read", None),
|
| 94 |
# token=st.secrets.get("API_KEY", None),
|
| 95 |
# token=os.getenv("API_KEY")
|
| 96 |
token=os.getenv("rainytrek010526001read")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
|
| 99 |
)
|
|
@@ -111,19 +116,21 @@ def extract_cities_with_llm(user_prompt: str) -> list[str]:
|
|
| 111 |
]
|
| 112 |
|
| 113 |
response = client.chat_completion(messages=messages, max_tokens=256, temperature=0.1)
|
| 114 |
-
|
| 115 |
-
|
| 116 |
|
| 117 |
raw = response.choices[0].message.content.strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
-
# strip markdown fences if present
|
| 120 |
-
if raw.startswith("```"):
|
| 121 |
-
raw = raw.split("```")[1]
|
| 122 |
-
if raw.startswith("json"):
|
| 123 |
-
raw = raw[4:]
|
| 124 |
-
raw = raw.strip()
|
| 125 |
|
|
|
|
|
|
|
| 126 |
cities = json.loads(raw)
|
|
|
|
| 127 |
return [c.strip() for c in cities if isinstance(c, str) and c.strip()]
|
| 128 |
|
| 129 |
|
|
@@ -235,7 +242,7 @@ if run and user_input.strip():
|
|
| 235 |
cities = []
|
| 236 |
|
| 237 |
if not cities:
|
| 238 |
-
st.markdown('<div class="error-box">
|
| 239 |
else:
|
| 240 |
st.markdown(f'<div class="llm-box"><div class="llm-label">Cities detected by LLM</div>{" · ".join(cities)}</div>', unsafe_allow_html=True)
|
| 241 |
|
|
|
|
| 84 |
"""Use HF Inference API to extract city names from a natural-language prompt."""
|
| 85 |
client = InferenceClient(
|
| 86 |
# model="mistralai/Mistral-7B-Instruct-v0.3",
|
| 87 |
+
# model="mistralai/Mistral-7B-Instruct-v0.2",
|
| 88 |
# model="nvidia/Gemma-4-26B-A4B-NVFP4",
|
| 89 |
# model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
| 90 |
+
model="meta-llama/Llama-3.3-70B-Instruct",
|
| 91 |
|
| 92 |
# token=st.secrets.get("HF_TOKEN", None),
|
| 93 |
# token=st.secrets.get("rainytrek010526001read", None),
|
| 94 |
# token=st.secrets.get("API_KEY", None),
|
| 95 |
# token=os.getenv("API_KEY")
|
| 96 |
token=os.getenv("rainytrek010526001read")
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
provider="fireworks-ai"
|
| 102 |
|
| 103 |
|
| 104 |
)
|
|
|
|
| 116 |
]
|
| 117 |
|
| 118 |
response = client.chat_completion(messages=messages, max_tokens=256, temperature=0.1)
|
|
|
|
|
|
|
| 119 |
|
| 120 |
raw = response.choices[0].message.content.strip()
|
| 121 |
+
# # strip markdown fences if present
|
| 122 |
+
# if raw.startswith("```"):
|
| 123 |
+
# raw = raw.split("```")[1]
|
| 124 |
+
# if raw.startswith("json"):
|
| 125 |
+
# raw = raw[4:]
|
| 126 |
+
# raw = raw.strip()
|
| 127 |
|
| 128 |
+
# cities = json.loads(raw)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
+
## raw = [c.strip() for c in response["choices"][0]["message"]["content"].split("\"") if len(c.strip()) > 1]
|
| 131 |
+
|
| 132 |
cities = json.loads(raw)
|
| 133 |
+
|
| 134 |
return [c.strip() for c in cities if isinstance(c, str) and c.strip()]
|
| 135 |
|
| 136 |
|
|
|
|
| 242 |
cities = []
|
| 243 |
|
| 244 |
if not cities:
|
| 245 |
+
st.markdown('<div class="error-box">what cities in particular should I look for?</div>', unsafe_allow_html=True)
|
| 246 |
else:
|
| 247 |
st.markdown(f'<div class="llm-box"><div class="llm-label">Cities detected by LLM</div>{" · ".join(cities)}</div>', unsafe_allow_html=True)
|
| 248 |
|