brainworm2024 committed on
Commit
bece13e
·
1 Parent(s): f21e0aa

Connect to real AMD MI300X via vLLM

Browse files
Files changed (1) hide show
  1. src/inference/qwen.rs +51 -11
src/inference/qwen.rs CHANGED
@@ -1,13 +1,53 @@
1
  use anyhow::Result;
 
2
 
3
- /// Mock inference for local testing / HF Space CPU.
4
- /// Returns (generated_text, model_used, device_info)
5
- pub async fn generate(_redacted_prompt: &str) -> Result<(String, String, String)> {
6
- tracing::info!("[MOCK] Inference skipped – returning placeholder");
7
- tokio::time::sleep(std::time::Duration::from_millis(10)).await;
8
- Ok((
9
- "Triage result: non‑urgent (mock)".to_string(),
10
- "mock".to_string(),
11
- "CPU (mock)".to_string(),
12
- ))
13
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  use anyhow::Result;
2
+ use serde_json::Value;
3
 
4
+ /// Call the real vLLM endpoint on AMD MI300X.
5
+ /// Falls back to mock if the GPU is unreachable.
6
+ pub async fn generate(redacted_prompt: &str) -> Result<(String, String, String)> {
7
+ let vllm_url = std::env::var("VLLM_URL")
8
+ .unwrap_or_else(|_| "http://localhost:8000/v1/completions".to_string());
9
+ let api_key = std::env::var("VLLM_API_KEY")
10
+ .unwrap_or_else(|_| "abc-123".to_string());
11
+
12
+ let client = reqwest::Client::new();
13
+ let resp = client
14
+ .post(&vllm_url)
15
+ .header("Authorization", format!("Bearer {}", api_key))
16
+ .json(&serde_json::json!({
17
+ "model": "Qwen/Qwen2.5-7B-Instruct",
18
+ "prompt": redacted_prompt,
19
+ "max_tokens": 250,
20
+ "temperature": 0.7
21
+ }))
22
+ .send()
23
+ .await;
24
+
25
+ match resp {
26
+ Ok(r) if r.status().is_success() => {
27
+ let json: Value = r.json().await?;
28
+ let text = json["choices"][0]["text"]
29
+ .as_str()
30
+ .unwrap_or("No output")
31
+ .trim()
32
+ .to_string();
33
+ tracing::info!("vLLM inference completed on MI300X ({} chars)", text.len());
34
+ Ok((text, "7B (vLLM)".to_string(), "ROCm/MI300X".to_string()))
35
+ }
36
+ Ok(r) => {
37
+ tracing::warn!("vLLM returned {} – falling back to mock", r.status());
38
+ Ok((
39
+ "Triage result: non‑urgent (mock – GPU unavailable)".to_string(),
40
+ "mock".to_string(),
41
+ "CPU (fallback)".to_string(),
42
+ ))
43
+ }
44
+ Err(e) => {
45
+ tracing::warn!("vLLM unreachable: {} – falling back to mock", e);
46
+ Ok((
47
+ "Triage result: non‑urgent (mock – GPU unavailable)".to_string(),
48
+ "mock".to_string(),
49
+ "CPU (fallback)".to_string(),
50
+ ))
51
+ }
52
+ }
53
+ }