File size: 4,678 Bytes
5187368
d60da4f
df78c68
e06dc15
 
 
 
 
375924d
e06dc15
 
5187368
e06dc15
 
 
 
 
 
 
375924d
5187368
e06dc15
 
 
 
5187368
375924d
5187368
375924d
e06dc15
 
 
375924d
e06dc15
 
 
 
 
 
 
 
d60da4f
 
 
 
e06dc15
 
 
 
 
 
 
 
 
 
 
 
 
 
e7cf650
5187368
 
d60da4f
e7cf650
d60da4f
 
e06dc15
 
 
d60da4f
e06dc15
 
 
 
375924d
9ad188a
 
 
d60da4f
 
 
 
9ad188a
 
 
 
 
 
e06dc15
 
df78c68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e06dc15
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Two-tier LLM client — primary / fallback, both Ollama Cloud over OpenAI-compatible HTTP.
import re
from collections.abc import Iterator
from functools import lru_cache
from typing import Any

from openai import OpenAI

from backend.config.settings import settings


@lru_cache(maxsize=2)
def _build_client(base_url: str, api_key: str) -> OpenAI:
    """Return an OpenAI-compatible client for *base_url*, memoised per (base_url, api_key).

    The cache is sized for two entries, one per tier configuration that
    ``get_client`` asks for (primary and fallback).
    """
    client = OpenAI(base_url=base_url, api_key=api_key)
    return client


def get_client(tier: str | None = None) -> OpenAI:
    """Return the cached OpenAI-compatible client for an LLM tier.

    Args:
        tier: "primary" or "fallback". A falsy value uses
            ``settings.active_llm_tier``.

    Returns:
        The memoised client built from that tier's base URL and API key.

    Raises:
        ValueError: if the resolved tier is neither "primary" nor "fallback".
            This mirrors ``active_model`` so a typo cannot silently route
            traffic to the primary endpoint.
    """
    resolved = tier or settings.active_llm_tier
    if resolved == "primary":
        return _build_client(settings.primary_base_url, settings.primary_api_key)
    if resolved == "fallback":
        return _build_client(settings.fallback_base_url, settings.fallback_api_key)
    raise ValueError(f"Unknown LLM tier: '{resolved}'. Must be primary/fallback.")


def active_model(tier: str | None = None) -> str:
    """Return the configured model name for an LLM tier.

    Args:
        tier: "primary" or "fallback". A falsy value uses
            ``settings.active_llm_tier``.

    Raises:
        ValueError: if the resolved tier is not one of the two known tiers.
    """
    resolved = tier or settings.active_llm_tier
    model_for_tier = dict(
        primary=settings.primary_model,
        fallback=settings.fallback_model,
    )
    if resolved in model_for_tier:
        return model_for_tier[resolved]
    raise ValueError(f"Unknown LLM tier: '{resolved}'. Must be primary/fallback.")


def _apply_no_think(messages: list[dict]) -> list[dict]:
    # Prepend /no_think to first user message (Ollama thinking suppression).
    result = list(messages)
    for i, msg in enumerate(result):
        if msg.get("role") == "user":
            result[i] = {**msg, "content": f"/no_think\n\n{msg['content']}"}
            break
    return result


def _strip_think_tags(text: str) -> str:
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()


def chat_complete(
    messages: list[dict],
    max_tokens: int,
    tier: str | None = None,
    temperature: float = 0.7,
    **kwargs: Any,
) -> str:
    """Run a non-streaming chat completion and return the cleaned reply text.

    Args:
        messages: Chat messages sent to ``client.chat.completions.create``.
        max_tokens: Token budget for the answer. In "strip" and "full"
            thinking modes ``settings.thinking_token_budget`` is added on top.
        tier: "primary" or "fallback". A falsy value uses
            ``settings.active_llm_tier``.
        temperature: Sampling temperature.
        **kwargs: Forwarded to ``create``. An ``extra_body`` entry is popped
            and passed explicitly (as ``None`` when empty).

    Returns:
        The reply with surrounding whitespace removed. In "off", "strip" and
        "suppress" thinking modes any ``<think>`` reasoning is removed as well.
        May be "" (a warning is printed in that case).

    Raises:
        ValueError: for an unknown tier (raised by ``active_model``).
    """
    resolved_tier = tier or settings.active_llm_tier
    model = active_model(resolved_tier)
    client = get_client(resolved_tier)

    mode = settings.thinking_mode
    extra_body: dict[str, Any] = kwargs.pop("extra_body", {})

    patched_messages = messages
    if mode == "suppress":
        patched_messages = _apply_no_think(messages)

    # Thinking tokens consume the same max_tokens budget, so reserve extra
    # room whenever the model is allowed to think.
    effective_max_tokens = max_tokens
    if mode in ("strip", "full"):
        effective_max_tokens = max_tokens + settings.thinking_token_budget

    resp = client.chat.completions.create(
        model=model,
        messages=patched_messages,
        max_tokens=effective_max_tokens,
        temperature=temperature,
        extra_body=extra_body or None,
        **kwargs,
    )
    raw = (resp.choices[0].message.content if resp.choices else "") or ""
    print(
        f"[llm_client] tier={resolved_tier} model={model} raw_len={len(raw)} raw={raw[:200]!r}"
    )

    # "suppress" is included because the /no_think soft switch may still emit
    # a (typically empty) <think></think> block that must not reach callers.
    # Stripping is a no-op when no tags are present.
    if mode in ("off", "strip", "suppress"):
        raw = _strip_think_tags(raw)

    stripped = raw.strip()
    if not stripped:
        print(
            f"[llm_client] WARNING: empty response after strip. finish_reason={resp.choices[0].finish_reason if resp.choices else 'none'}"
        )
    return stripped


def chat_complete_stream(
    messages: list[dict],
    max_tokens: int,
    tier: str | None = None,
    temperature: float = 0.7,
    **kwargs: Any,
) -> Iterator[str]:
    """Yield non-empty content deltas as they arrive from a streaming completion.

    No thinking-tag stripping is done here: raw ``<think>…</think>`` text is
    yielded as-is. Callers should buffer the text and post-process it once the
    stream ends (``finalize_streamed``).

    Request preparation matches the non-streaming path: ``/no_think`` is
    injected in "suppress" mode, and ``settings.thinking_token_budget`` is
    added to ``max_tokens`` in "strip" and "full" modes. An ``extra_body``
    kwarg is popped and passed explicitly; other kwargs are forwarded to
    ``create``.

    The underlying HTTP stream is closed when iteration finishes, raises, or
    the consumer closes the generator early.

    Raises:
        ValueError: for an unknown tier (raised by ``active_model``).
    """
    resolved_tier = tier or settings.active_llm_tier
    model = active_model(resolved_tier)
    client = get_client(resolved_tier)

    patched_messages = messages
    extra_body: dict[str, Any] = kwargs.pop("extra_body", {})

    if settings.thinking_mode == "suppress":
        patched_messages = _apply_no_think(messages)

    effective_max_tokens = max_tokens
    if settings.thinking_mode in ("strip", "full"):
        effective_max_tokens = max_tokens + settings.thinking_token_budget

    stream = client.chat.completions.create(
        model=model,
        messages=patched_messages,
        max_tokens=effective_max_tokens,
        temperature=temperature,
        stream=True,
        extra_body=extra_body or None,
        **kwargs,
    )
    try:
        for chunk in stream:
            if not chunk.choices:
                continue
            piece = getattr(chunk.choices[0].delta, "content", None)
            if piece:
                yield piece
    finally:
        # If the consumer stops early, generator.close() raises GeneratorExit
        # at the yield. Close the response so the connection is released now
        # rather than whenever the stream object is garbage-collected.
        stream.close()


def finalize_streamed(text: str) -> str:
    """Post-process the fully buffered text of a finished stream.

    In "off", "strip" and "suppress" thinking modes, ``<think>`` reasoning is
    removed. "suppress" is included because the /no_think soft switch may
    still emit a (typically empty) think block. Surrounding whitespace is
    always stripped.
    """
    if settings.thinking_mode in ("off", "strip", "suppress"):
        text = _strip_think_tags(text)
    return text.strip()


def warmup(tier: str | None = None) -> None:
    """Send a minimal completion through ``chat_complete`` and discard the reply.

    The request is a single "hi" user message with a 5-token cap at
    temperature 0. Presumably this is meant to prime the client connection or
    model load before real traffic; confirm with callers. Any exception from
    ``chat_complete`` propagates.
    """
    probe = [{"role": "user", "content": "hi"}]
    chat_complete(probe, 5, tier=tier, temperature=0.0)