File size: 2,263 Bytes
4a90885
 
99f62cc
 
 
eff14fc
 
 
 
 
 
 
 
 
5970ba6
 
4a90885
 
 
 
 
 
 
99f62cc
 
 
 
 
 
 
 
5970ba6
 
eff14fc
dcb7e13
 
 
99f62cc
 
4a90885
dcb7e13
4a90885
5970ba6
eff14fc
 
 
99f62cc
 
dcb7e13
99f62cc
 
4a90885
 
 
 
 
 
 
 
 
eff14fc
 
4a90885
eff14fc
4a90885
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
use axum::{routing::{get, post}, Router, response::Json};
use serde::Serialize;
use tower_http::trace::TraceLayer;
use tracing_subscriber::EnvFilter;

mod handlers;
mod inference;
mod shield;
mod web3;
mod orchestrator;
mod proof;
mod api;
mod federation;

// Import the `triage` handler module from this crate's own `handlers` module, not from a separate library crate.
use handlers::triage;

/// JSON body returned by the `GET /status` endpoint.
#[derive(Serialize)]
struct StatusResponse {
    /// Server state; the `status` handler always sets this to `"running"`.
    status: String,
    /// Human-readable description of the model variant in use.
    model: String,
    /// Human-readable description of the compute backend in use.
    device: String,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    dotenvy::dotenv().ok();
    tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env().add_directive("rustvital_amd=debug".parse()?))
        .init();

    // Init zk-lite signing key
    proof::init_signing_key("0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef");

    let port = std::env::var("PORT").unwrap_or_else(|_| "3000".to_string());
    let addr = format!("0.0.0.0:{}", port);
    tracing::info!("Starting RustVital-AMD server on {}", addr);

    let app = Router::new()
        .route("/", get(serve_ui))
        .route("/health", get(|| async { "healthy" }))
        .route("/status", get(status))
        .route("/triage", post(triage::handle))
        .route("/triage/stream", get(api::stream::triage_stream))
        .route("/trigger-federated-tune", post(federation::trigger_tune))
        .route("/federation/round", get(federation::latest_round))
        .layer(TraceLayer::new_for_http());

    let listener = tokio::net::TcpListener::bind(&addr).await?;
    axum::serve(listener, app).await?;
    Ok(())
}

/// `GET /` — serve the single-page web UI.
///
/// The page is embedded into the binary at compile time from
/// `../static/index.html` (relative to this source file), so no file is read
/// per request.
async fn serve_ui() -> axum::response::Html<&'static str> {
    const INDEX_PAGE: &str = include_str!("../static/index.html");
    axum::response::Html(INDEX_PAGE)
}

/// `GET /status` — report server state, model variant, and compute device.
///
/// Everything is derived from environment variables at request time.
///
/// Device (first match wins):
/// - `ENABLE_ROCM` is exactly `"1"` → `"ROCm/HIP (MI300X)"`
/// - `VLLM_URL` is set to any valid-Unicode value → `"ROCm/MI300X (vLLM connected)"`
/// - otherwise → `"CPU (fallback)"`
///
/// Model: `FORCE_0_5B` set to any valid-Unicode value selects the 0.5B
/// description; otherwise the 7B description is reported.
async fn status() -> Json<StatusResponse> {
    let rocm_enabled = matches!(std::env::var("ENABLE_ROCM").as_deref(), Ok("1"));
    let vllm_connected = std::env::var("VLLM_URL").is_ok();

    let device: &str = match (rocm_enabled, vllm_connected) {
        (true, _) => "ROCm/HIP (MI300X)",
        (false, true) => "ROCm/MI300X (vLLM connected)",
        (false, false) => "CPU (fallback)",
    };

    let model: &str = if std::env::var("FORCE_0_5B").is_ok() {
        "0.5B (Qwen2.5-0.5B-Instruct)"
    } else {
        "7B (Qwen2.5-7B-Instruct)"
    };

    Json(StatusResponse {
        status: String::from("running"),
        model: model.to_owned(),
        device: device.to_owned(),
    })
}