gaurv007 committed on
Commit
5453fed
·
verified ·
1 Parent(s): 640aecf

Upload alpha_factory/infra/llm_client.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. alpha_factory/infra/llm_client.py +102 -0
alpha_factory/infra/llm_client.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Client — unified interface to vLLM / Ollama with guided JSON generation.
3
+ All outputs are schema-constrained. No free-text alpha generation.
4
+ """
5
+ import asyncio
6
+ import json
7
+ from typing import TypeVar
8
+ from pydantic import BaseModel
9
+ from openai import AsyncOpenAI
10
+ from ..config import LLMConfig
11
+
12
+ T = TypeVar("T", bound=BaseModel)
13
+
14
+
15
class LLMClient:
    """
    Async LLM client with structured JSON output.

    Connects to vLLM or Ollama (both expose an OpenAI-compatible API) and
    tracks cumulative token usage across all requests made through this
    instance.
    """

    def __init__(self, config: LLMConfig):
        """
        Args:
            config: LLM connection/settings object providing base_url,
                api_key, model names, temperatures, and max_tokens.
        """
        self.config = config
        self.client = AsyncOpenAI(
            base_url=config.base_url,
            api_key=config.api_key,
        )
        # Cumulative total_tokens across all completed requests.
        self._token_count = 0

    def _record_usage(self, response) -> None:
        """Accumulate token usage from a completion response, if reported."""
        # Some servers omit `usage`; treat a missing block as zero tokens.
        if response.usage:
            self._token_count += response.usage.total_tokens

    async def generate_json(
        self,
        prompt: str,
        schema: type[T],
        model: str | None = None,
        temperature: float | None = None,
        system_prompt: str = "You are a quantitative finance expert.",
    ) -> T:
        """
        Generate a structured JSON response conforming to the given Pydantic schema.

        Uses guided decoding via ``response_format`` (vLLM supports this natively).

        Args:
            prompt: User prompt text.
            schema: Pydantic model class the output must conform to.
            model: Override model name; defaults to ``config.mediumfish_model``.
            temperature: Override sampling temperature; defaults to
                ``config.temperature_generation``. An explicit 0.0 is honored.
            system_prompt: System message prepended to the conversation.

        Returns:
            A validated instance of ``schema``.

        Raises:
            ValueError: If the model returned no content.
            json.JSONDecodeError: If the returned content is not valid JSON.
            pydantic.ValidationError: If the JSON does not match ``schema``.
        """
        model = model or self.config.mediumfish_model
        # BUG FIX: `temperature or default` discarded an explicit 0.0
        # (falsy); only fall back when the caller passed None.
        temp = (
            temperature
            if temperature is not None
            else self.config.temperature_generation
        )

        # Build JSON schema for guided generation.
        json_schema = schema.model_json_schema()

        response = await self.client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            temperature=temp,
            max_tokens=self.config.max_tokens,
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": schema.__name__,
                    "schema": json_schema,
                },
            },
        )

        content = response.choices[0].message.content
        self._record_usage(response)

        # Fail loudly instead of passing None to json.loads (opaque TypeError).
        if content is None:
            raise ValueError(
                f"Model returned no content for schema {schema.__name__}"
            )

        # Parse and validate against the requested schema.
        data = json.loads(content)
        return schema.model_validate(data)

    async def generate_text(
        self,
        prompt: str,
        model: str | None = None,
        temperature: float | None = None,
        system_prompt: str = "You are a quantitative finance expert.",
        max_tokens: int = 2048,
    ) -> str:
        """Generate free-text response (for memos/reports only, never for expressions).

        Args:
            prompt: User prompt text.
            model: Override model name; defaults to ``config.mediumfish_model``.
            temperature: Override sampling temperature; defaults to
                ``config.temperature_critique``. An explicit 0.0 is honored.
            system_prompt: System message prepended to the conversation.
            max_tokens: Completion length cap for this call.

        Returns:
            The generated text ("" if the model returned no content, so the
            declared ``str`` return type always holds).
        """
        model = model or self.config.mediumfish_model
        # BUG FIX: same falsy-zero issue as generate_json.
        temp = (
            temperature
            if temperature is not None
            else self.config.temperature_critique
        )

        response = await self.client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            temperature=temp,
            max_tokens=max_tokens,
        )

        content = response.choices[0].message.content
        self._record_usage(response)
        # Normalize a None completion to "" to honor the -> str annotation.
        return content or ""

    @property
    def tokens_used(self) -> int:
        """Total tokens consumed by all requests since the last reset."""
        return self._token_count

    def reset_token_count(self) -> None:
        """Zero the cumulative token counter (e.g. at the start of a run)."""
        self._token_count = 0