physix / frontend /src /hooks /useLlmEpisodeRunner.ts
Pratyush-01's picture
Upload folder using huggingface_hub
0e24aff verified
/** Drives an LLM-backed episode: starts a session, then loops `llm-step`
* calls (run / pause / step-once / end) and accumulates `LlmTurn` records.
* Render-agnostic; consumed by `RunWithLlmPane`. */
import { useCallback, useEffect, useRef, useState } from "react";
import {
InteractiveApiError,
type InteractiveClient,
type LlmModelInfo,
type LlmStepResponse,
type SystemDescriptor,
} from "@/lib/interactiveClient";
import { InteractiveClient as DefaultClient } from "@/lib/interactiveClient";
import type { LlmConnection } from "@/lib/llmPresets";
import { OLLAMA_OPENAI_BASE_URL } from "@/lib/llmPresets";
import type {
PhysiXAction,
PhysiXObservation,
TrajectorySample,
} from "@/types/physix";
/**
 * Lifecycle phase of the episode runner.
 *
 * - `idle`     – no session (initial state, and after `end()`).
 * - `starting` – a session is being created on the server.
 * - `running`  – an LLM step is being requested (autoplay or single step).
 * - `paused`   – a session exists but no further steps are being requested.
 * - `ended`    – the last observation reported `done`.
 * - `error`    – a request failed; see `errorMessage`.
 */
export type RunnerStatus =
  "idle" | "starting" | "running" | "paused" | "ended" | "error";
/** One completed LLM step, as recorded by the runner for rendering. */
export interface LlmTurn {
  /** Turn counter taken from the step's resulting observation. */
  turn: number;
  /** Action the model proposed for this turn. */
  action: PhysiXAction;
  /** Observation returned by the environment after applying the action. */
  observation: PhysiXObservation;
  /** Trajectory samples returned alongside the step response. */
  predictedTrajectory: TrajectorySample[];
  /** Unparsed model output for this turn. */
  rawCompletion: string;
  /** Step latency as reported by the server (`latency_s`; presumably seconds). */
  latencyS: number;
  /** Model identifier reported by the server for this step. */
  model: string;
}
/** Render-facing snapshot exposed by `useLlmEpisodeRunner`. */
export interface LlmEpisodeRunnerState {
  /** Current lifecycle phase of the runner. */
  status: RunnerStatus;
  /** Message of the most recent failure, or `null` when there is none. */
  errorMessage: string | null;
  /** System catalogue (`null` until the first successful fetch). */
  systems: SystemDescriptor[] | null;
  /** Locally-pulled Ollama model tags (`null` = still loading). */
  models: LlmModelInfo[] | null;
  /** Set when the server couldn't talk to Ollama; UI surfaces a hint. */
  modelsError: string | null;
  /** Resolved system_id of the active episode (server-decided if user passed none). */
  systemId: string | null;
  /** Id of the active server session, or `null` when none is open. */
  sessionId: string | null;
  /** Reset observation; trajectory + hint live here for the whole episode. */
  initialObservation: PhysiXObservation | null;
  /** Turns recorded so far for the active episode, oldest first. */
  turns: LlmTurn[];
  /** Turn budget reported by the server at session start (0 before that). */
  maxTurns: number;
}
/** Imperative controls returned by `useLlmEpisodeRunner` next to its state. */
export interface LlmEpisodeRunnerControls {
  /** Re-fetch the system catalogue. */
  refreshCatalogue: () => Promise<void>;
  /** Re-fetch the list of available models. */
  refreshModels: () => Promise<void>;
  /**
   * Open a new session (closing any previous one) and begin autoplaying.
   * `connection` selects the LLM endpoint; the remaining fields are optional
   * and, when omitted, are left for the server / runner defaults.
   */
  start: (options: {
    systemId?: string | undefined;
    seed?: number | undefined;
    maxTurns?: number | undefined;
    connection: LlmConnection;
    temperature?: number | undefined;
  }) => Promise<void>;
  /** Pause an autoplaying loop without ending the session. */
  pause: () => void;
  /** Resume the loop from where it stopped. */
  resume: () => Promise<void>;
  /** Run one turn manually (also works while paused). */
  stepOnce: () => Promise<void>;
  /** End the session and clear local state. */
  end: () => Promise<void>;
  /** Dismiss the current error message. */
  resetError: () => void;
}
/**
 * State the hook starts from: no session, nothing fetched yet.
 * `systems` / `models` are `null` (not `[]`) so consumers can tell
 * "not loaded" apart from "loaded but empty".
 */
const INITIAL_STATE: LlmEpisodeRunnerState = {
  status: "idle",
  errorMessage: null,
  systems: null,
  models: null,
  modelsError: null,
  systemId: null,
  sessionId: null,
  initialObservation: null,
  turns: [],
  maxTurns: 0,
};
/** Per-episode LLM settings that are sent with every step request. */
interface RunnerSettings {
  /** Endpoint, model and credentials to call. */
  connection: LlmConnection;
  /** Sampling temperature forwarded to the model. */
  temperature: number;
}
/**
 * Settings in effect before the first `start()` call. The temperature also
 * serves as the fallback when `start()` is called without one.
 */
const DEFAULT_SETTINGS: RunnerSettings = {
  connection: {
    endpointId: "ollama",
    baseUrl: OLLAMA_OPENAI_BASE_URL,
    model: "qwen2.5:3b-instruct",
    apiKey: "",
  },
  temperature: 0.7,
};
/**
 * React hook that owns a single LLM-driven episode at a time.
 *
 * It fetches the system catalogue and model list on mount, opens a server
 * session on `start()`, and then requests `llm-step` turns either in an
 * autoplay loop or one at a time, appending every result to `turns`.
 *
 * Concurrency guarantees:
 * - At most one step driver runs per session, so `pause()` followed by a
 *   quick `resume()` (or `stepOnce()`) re-uses the driver whose request is
 *   still in flight instead of issuing a second, concurrent `llm-step`.
 * - Responses and errors that arrive for a session that has meanwhile been
 *   ended or replaced are discarded.
 * - Overlapping `start()` calls resolve to the most recent one; sessions
 *   opened by superseded calls are closed again.
 * - On unmount the loop is stopped and the open session is ended
 *   (best-effort).
 *
 * @param clientOverride Optional API client (e.g. a test double). It is read
 *   on the first render only; later changes are ignored.
 * @returns The current runner state merged with its control functions.
 */
export function useLlmEpisodeRunner(
  clientOverride?: InteractiveClient,
): LlmEpisodeRunnerState & LlmEpisodeRunnerControls {
  // Lazy initialiser: the default client is built once, not on every render.
  const [client] = useState<InteractiveClient>(
    () => clientOverride ?? new DefaultClient(),
  );
  const [state, setState] = useState<LlmEpisodeRunnerState>(INITIAL_STATE);

  // Id of the session the UI currently cares about (null = none).
  const sessionIdRef = useRef<string | null>(null);
  const settingsRef = useRef<RunnerSettings>(DEFAULT_SETTINGS);
  // "Stop after the step that is in flight": set by pause / step-once / end,
  // cleared by resume.
  const stopRef = useRef<boolean>(false);
  // Session id whose step driver is currently executing (null = none). Used
  // to guarantee a single driver per session.
  const driverSessionRef = useRef<string | null>(null);
  // Bumped by every start() / end() / unmount so a superseded start() can
  // tell that it lost the race.
  const startAttemptRef = useRef<number>(0);

  // ---- Catalogue (mount + refresh) ---------------------------------------

  const refreshCatalogue = useCallback(async () => {
    try {
      const systems = await client.listSystems();
      setState((prev) => ({ ...prev, systems, errorMessage: null }));
    } catch (error) {
      setState((prev) => ({
        ...prev,
        status: "error",
        errorMessage: extractMessage(error),
      }));
    }
  }, [client]);

  // Fetched separately from the system catalogue: a missing Ollama daemon
  // shouldn't blank out the system selector, and a missing systems catalogue
  // shouldn't blank out the model selector.
  const refreshModels = useCallback(async () => {
    try {
      const response = await client.listModels();
      setState((prev) => ({
        ...prev,
        models: response.models,
        modelsError: response.error ?? null,
      }));
    } catch (error) {
      setState((prev) => ({
        ...prev,
        models: [],
        modelsError: extractMessage(error),
      }));
    }
  }, [client]);

  useEffect(() => {
    void refreshCatalogue();
    void refreshModels();
  }, [refreshCatalogue, refreshModels]);

  // ---- Session teardown --------------------------------------------------

  /** Ends a server session, deliberately ignoring failures (best-effort). */
  const endQuietly = useCallback(
    async (sessionId: string): Promise<void> => {
      try {
        await client.endSession(sessionId);
      } catch {
        /* best-effort */
      }
    },
    [client],
  );

  // Best-effort cleanup if the user closes the tab mid-session.
  useEffect(() => {
    const handler = () => {
      const sessionId = sessionIdRef.current;
      if (!sessionId) return;
      // The client does not expose its base URL through its public type, so
      // it is read structurally and validated before use.
      const { baseUrl } = client as unknown as { baseUrl?: unknown };
      if (typeof baseUrl !== "string") return;
      const url = `${baseUrl}/interactive/sessions/${encodeURIComponent(sessionId)}`;
      try {
        // NOTE(review): sendBeacon always issues an HTTP POST — confirm the
        // server treats a POST to this route as "end session".
        navigator.sendBeacon?.(url);
      } catch {
        /* ignore */
      }
    };
    window.addEventListener("beforeunload", handler);
    return () => window.removeEventListener("beforeunload", handler);
  }, [client]);

  // Stop driving and release the session when the consumer unmounts;
  // otherwise the autoplay loop would keep calling the LLM in the background.
  useEffect(
    () => () => {
      stopRef.current = true;
      startAttemptRef.current += 1;
      const sessionId = sessionIdRef.current;
      sessionIdRef.current = null;
      if (sessionId) void endQuietly(sessionId);
    },
    [endQuietly],
  );

  // ---- Helpers -----------------------------------------------------------

  /** Converts a step response into an `LlmTurn` and appends it to `turns`. */
  const recordTurn = useCallback((response: LlmStepResponse) => {
    const turnRecord: LlmTurn = {
      turn: response.observation.turn,
      action: response.action,
      observation: response.observation,
      predictedTrajectory: response.predicted_trajectory,
      rawCompletion: response.raw_completion,
      latencyS: response.latency_s,
      model: response.model,
    };
    setState((prev) => ({
      ...prev,
      turns: [...prev.turns, turnRecord],
    }));
    return turnRecord;
  }, []);

  /**
   * Requests one LLM step for `sessionId`.
   *
   * @returns The recorded turn, or `null` when the request failed (status is
   *   then set to "error") or when the session was ended / replaced while the
   *   request was in flight (the result is dropped silently).
   */
  const callLlmStepOnce = useCallback(
    async (sessionId: string): Promise<LlmTurn | null> => {
      const { connection, temperature } = settingsRef.current;
      const isStale = () => sessionIdRef.current !== sessionId;
      try {
        const response = await client.llmStep(sessionId, {
          base_url: connection.baseUrl,
          model: connection.model,
          // `||` on purpose: an empty key means "send no key at all".
          api_key: connection.apiKey || undefined,
          temperature,
        });
        if (isStale()) return null;
        return recordTurn(response);
      } catch (error) {
        if (isStale()) return null;
        setState((prev) => ({
          ...prev,
          status: "error",
          errorMessage: extractMessage(error),
        }));
        return null;
      }
    },
    [client, recordTurn],
  );

  /**
   * Drives steps for the current session.
   *
   * - `"loop"`: keep stepping until paused, done, failed or superseded.
   * - `"single"`: run exactly one step, then pause.
   *
   * If a driver for this session is already executing (its request is still
   * in flight), no second driver is started. Updating `stopRef` is enough:
   * the running driver reads it after its step completes and either continues
   * (`"loop"`) or stops (`"single"`), so requests never overlap.
   */
  const runDriver = useCallback(
    async (mode: "loop" | "single"): Promise<void> => {
      const sessionId = sessionIdRef.current;
      if (!sessionId) return;
      stopRef.current = mode === "single";
      setState((prev) => ({ ...prev, status: "running", errorMessage: null }));
      if (driverSessionRef.current === sessionId) return;

      driverSessionRef.current = sessionId;
      try {
        do {
          const turn = await callLlmStepOnce(sessionId);
          if (turn === null) return; // failed, or session ended / replaced
          if (turn.observation.done) {
            stopRef.current = true;
            setState((prev) => ({ ...prev, status: "ended" }));
            return;
          }
        } while (!stopRef.current && sessionIdRef.current === sessionId);
        // Stopped by pause()/stepOnce(). Only touch the status when this
        // session is still the active one.
        if (sessionIdRef.current === sessionId) {
          setState((prev) =>
            prev.status === "running" ? { ...prev, status: "paused" } : prev,
          );
        }
      } finally {
        // A newer session may already have registered its own driver.
        if (driverSessionRef.current === sessionId) {
          driverSessionRef.current = null;
        }
      }
    },
    [callLlmStepOnce],
  );

  // ---- Controls ----------------------------------------------------------

  const start = useCallback(
    async (options: {
      systemId?: string | undefined;
      seed?: number | undefined;
      maxTurns?: number | undefined;
      connection: LlmConnection;
      temperature?: number | undefined;
    }) => {
      const attempt = ++startAttemptRef.current;
      const isSuperseded = () => attempt !== startAttemptRef.current;

      // Detach from the prior session first so its driver stops requesting
      // steps and any late response from it is discarded.
      stopRef.current = true;
      const prior = sessionIdRef.current;
      sessionIdRef.current = null;
      settingsRef.current = {
        connection: options.connection,
        temperature: options.temperature ?? DEFAULT_SETTINGS.temperature,
      };
      setState((prev) => ({
        ...prev,
        status: "starting",
        errorMessage: null,
        sessionId: null,
        turns: [],
        initialObservation: null,
      }));

      // Tear down any prior session.
      if (prior) await endQuietly(prior);
      if (isSuperseded()) return;

      try {
        const response = await client.startSession({
          system_id: options.systemId,
          seed: options.seed,
          max_turns: options.maxTurns,
        });
        if (isSuperseded()) {
          // A newer start()/end() won the race; don't leak this session.
          await endQuietly(response.session_id);
          return;
        }
        sessionIdRef.current = response.session_id;
        setState((prev) => ({
          ...prev,
          status: "paused", // ready to run; the driver flips it to "running"
          systemId: response.system.system_id,
          sessionId: response.session_id,
          initialObservation: response.observation,
          maxTurns: response.max_turns,
          turns: [],
        }));
        // Kick off the run-to-done loop. Caller can pause at any time.
        void runDriver("loop");
      } catch (error) {
        if (isSuperseded()) return;
        setState((prev) => ({
          ...prev,
          status: "error",
          errorMessage: extractMessage(error),
        }));
      }
    },
    [client, endQuietly, runDriver],
  );

  const pause = useCallback(() => {
    stopRef.current = true;
    setState((prev) =>
      prev.status === "running" ? { ...prev, status: "paused" } : prev,
    );
  }, []);

  const resume = useCallback(async () => {
    if (!sessionIdRef.current) return;
    // Fire-and-forget, as before: the promise resolves once the loop has been
    // (re)started, not when the episode finishes.
    void runDriver("loop");
  }, [runDriver]);

  const stepOnce = useCallback(async () => {
    if (!sessionIdRef.current) return;
    await runDriver("single");
  }, [runDriver]);

  const end = useCallback(async () => {
    stopRef.current = true;
    startAttemptRef.current += 1; // cancels a start() that is still pending
    const sessionId = sessionIdRef.current;
    sessionIdRef.current = null;
    // Clear local state before awaiting the server so a start() issued while
    // the request is pending is not overwritten afterwards.
    setState((prev) => ({
      ...prev,
      status: "idle",
      sessionId: null,
      turns: [],
      initialObservation: null,
    }));
    if (sessionId) await endQuietly(sessionId);
  }, [endQuietly]);

  const resetError = useCallback(() => {
    setState((prev) => ({ ...prev, errorMessage: null }));
  }, []);

  return {
    ...state,
    refreshCatalogue,
    refreshModels,
    start,
    pause,
    resume,
    stepOnce,
    end,
    resetError,
  };
}
/**
 * Turns an arbitrary thrown value into a message suitable for display.
 *
 * API errors contribute their `detail`, other `Error`s their `message`;
 * anything else (strings, plain objects, …) yields a generic fallback.
 */
function extractMessage(error: unknown): string {
  return error instanceof InteractiveApiError
    ? error.detail
    : error instanceof Error
      ? error.message
      : "Unknown error";
}