calcal / app.py
workcha's picture
Upload 19 files
61b4ff5 verified
"""
6์ฃผ์ฐจ ์‹ค์Šต: HuggingFace Space ์นผ๋กœ๋ฆฌ ์นด์šดํ„ฐ (LangChain LCEL ยท ํ•™์ƒ ๋ฒ„์ „)
=====================================================================
์Œ์‹ ์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•˜๋ฉด
1) HF Inference API์˜ ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ๋ชจ๋ธ๋กœ ์Œ์‹์„ ์ธ์‹ํ•˜๊ณ 
2) ๊ทธ ๊ฒฐ๊ณผ๋ฅผ LangChain ChatHuggingFace LLM์— ๋„˜๊ฒจ ์นผ๋กœ๋ฆฌ/์˜์–‘์†Œ๋ฅผ ์ถ”์ •ํ•œ ๋’ค
3) Gradio UI๋กœ ๋ณด์—ฌ์ค€๋‹ค.
ํ•ต์‹ฌ ๋ณ€๊ฒฝ: estimate_calories๋Š” LCEL ์ฒด์ธ(prompt | llm | parser)์œผ๋กœ ๊ตฌ์„ฑํ•œ๋‹ค.
์ด ํŒŒ์ผ์„ ๊ทธ๋Œ€๋กœ HuggingFace Space(Gradio SDK)์— ์˜ฌ๋ฆฌ๋ฉด ๋ฐฐํฌ๋œ๋‹ค.
TODO ๋กœ ํ‘œ์‹œ๋œ ๋ถ€๋ถ„์„ ์ฑ„์›Œ ์™„์„ฑํ•œ ๋’ค,
1) ๋กœ์ปฌ์—์„œ ์‹คํ–‰ํ•ด๋ณด๊ณ 
2) HuggingFace Space์— ๋ฐฐํฌํ•œ๋‹ค.
๋กœ์ปฌ ์‹คํ–‰:
uv run python app.py
"""
from __future__ import annotations
import json
import os
import tempfile
from typing import Any
import gradio as gr
from gradio_client import utils as _gc_utils # noqa: E402
# --- Workaround for gradio_client bug #10178: its JSON Schema walker crashes
# when it meets a bare boolean schema. Such schemas show up as the
# `additionalProperties: true` generated by the Label/JSON components.
# Both wrappers below map boolean schemas to "Any" and delegate every other
# schema to the original gradio_client implementation.
_orig_get_type = _gc_utils.get_type


def _safe_get_type(schema):
    """Boolean-safe replacement for ``gradio_client.utils.get_type``."""
    return "Any" if isinstance(schema, bool) else _orig_get_type(schema)


_gc_utils.get_type = _safe_get_type

_orig_j2p = _gc_utils._json_schema_to_python_type


def _safe_j2p(schema, defs=None):
    """Boolean-safe replacement for ``_json_schema_to_python_type``."""
    if not isinstance(schema, bool):
        return _orig_j2p(schema, defs)
    return "Any"


_gc_utils._json_schema_to_python_type = _safe_j2p
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from PIL import Image
from model_config import LLM_MODEL, VISION_MODEL, get_token
# How many classifier labels are kept, forwarded to the LLM and shown in the UI.
TOP_K = 3
# Pull variables from a local .env file (if one exists) into the environment.
load_dotenv()
# ---------------------------------------------------------------------------
# TODO 1. ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ
# ---------------------------------------------------------------------------
# LLM ์ด '์˜์–‘์‚ฌ AI' ์—ญํ• ์„ ํ•˜๊ณ , 1์ธ๋ถ„ ๊ธฐ์ค€ ์นผ๋กœ๋ฆฌ/ํƒ„๋‹จ์ง€๋ฅผ JSON ์œผ๋กœ ์ถœ๋ ฅํ•˜๋„๋ก
# ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ๋ฅผ ์ž‘์„ฑํ•˜๋ผ.
# - ๋ฐ˜๋“œ์‹œ ์•„๋ž˜ JSON ์Šคํ‚ค๋งˆ๋งŒ ์ถœ๋ ฅํ•˜๋ผ๊ณ  ๊ฐ•์ œํ•  ๊ฒƒ
# {"food": str, "confidence": float, "calories_kcal": int,
# "carbs_g": int, "protein_g": int, "fat_g": int, "note": str}
# - ChatPromptTemplate ์— ๋“ค์–ด๊ฐ€๋ฏ€๋กœ JSON ์˜ˆ์‹œ์˜ ์ค‘๊ด„ํ˜ธ๋Š” {{ }} ๋กœ ์ด์Šค์ผ€์ดํ”„ํ•  ๊ฒƒ
SYSTEM_PROMPT = """ ๋„ˆ๋Š” ํ•œ๊ตญ ์˜์–‘์‚ฌ AI๋‹ค.
์‚ฌ์šฉ์ž๊ฐ€ ์Œ์‹ ๋ถ„๋ฅ˜ ๊ฒฐ๊ณผ(top-k labels)๋ฅผ ์ฃผ๋ฉด,
๊ฐ€์žฅ ๊ฐ€๋Šฅ์„ฑ ๋†’์€ ์Œ์‹ 1๊ฐœ์˜ 1์ธ๋ถ„ ๊ธฐ์ค€ ์˜์–‘์ •๋ณด๋ฅผ ์ถ”์ •ํ•ด
๋ฐ˜๋“œ์‹œ ๋‹ค์Œ JSON ์Šคํ‚ค๋งˆ๋งŒ ์ถœ๋ ฅํ•˜๋ผ. ๋‹ค๋ฅธ ํ…์ŠคํŠธ/๋งˆํฌ๋‹ค์šด ๊ธˆ์ง€.
{{"food": "์Œ์‹๋ช…", "confidence": 0.0~1.0,
"calories_kcal": ์ •์ˆ˜, "carbs_g": ์ •์ˆ˜,
"protein_g": ์ •์ˆ˜, "fat_g": ์ •์ˆ˜,
"note": "์ถ”์ • ๊ทผ๊ฑฐ ํ•œ ์ค„"}}"
"""
# -----------------------------------------------------------------------------
# Lazily initialised client / chain singletons
# -----------------------------------------------------------------------------
# Both start as None and are created on first use by the *_lazy() helpers.
_vision_client: InferenceClient | None = None
_chain = None


def _vision_lazy() -> InferenceClient:
    """Return the shared vision InferenceClient, creating it on the first call."""
    global _vision_client
    if _vision_client is not None:
        return _vision_client
    _vision_client = InferenceClient(token=get_token())
    return _vision_client
def _chain_lazy():
    """Return the shared LCEL chain ``prompt | ChatHuggingFace | JsonOutputParser``.

    The chain is built on the first call and cached in the module-level
    ``_chain``. It is invoked with ``{"labels_json": <str>}`` and yields the
    value parsed from the model's JSON answer.
    """
    global _chain
    if _chain is None:
        # 3-1. Create the HF Inference endpoint for the text-generation LLM.
        endpoint = HuggingFaceEndpoint(
            repo_id=LLM_MODEL,
            task="text-generation",
            max_new_tokens=300,
            temperature=0.2,  # low temperature -> less random estimates
            huggingfacehub_api_token=get_token(),
        )
        # 3-2. Wrap the endpoint in LangChain's chat-model interface.
        llm = ChatHuggingFace(llm=endpoint)
        # 3-3. Prompt template: system instructions + the classifier's top-k labels.
        prompt = ChatPromptTemplate.from_messages([
            ("system", SYSTEM_PROMPT),
            ("human", "다음은 이미지 분류기의 top-k 결과다:\n{labels_json}"),
        ])
        # 3-4. The LCEL pipeline: this single line is the heart of the chain.
        _chain = prompt | llm | JsonOutputParser()
    return _chain
# -----------------------------------------------------------------------------
# Step 1: Image classification (no LangChain abstraction; InferenceClient is used directly)
# -----------------------------------------------------------------------------
def _label_score(item: Any) -> dict[str, Any]:
    """Normalize one classification result (dict or attribute object) to a plain dict."""
    if isinstance(item, dict):
        return {"label": item["label"], "score": float(item["score"])}
    return {"label": item.label, "score": float(item.score)}


def classify_food(image: Image.Image) -> list[dict[str, Any]]:
    """Classify a food photo with the HF image-classification model.

    Args:
        image: PIL image to classify; it is converted to RGB and encoded as JPEG.

    Returns:
        At most ``TOP_K`` dicts ``{"label": str, "score": float}``, taken from
        the front of the API response in the order it was returned.
    """
    client = _vision_lazy()
    # The image is sent as a temporary .jpg file because the hf-inference
    # router requires a Content-Type (presumably inferred from the extension).
    fd, tmp_path = tempfile.mkstemp(suffix=".jpg")
    try:
        with os.fdopen(fd, "wb") as fh:
            image.convert("RGB").save(fh, format="JPEG")
        raw = client.image_classification(tmp_path, model=VISION_MODEL)
    finally:
        # Remove the temp file on every path, including a failed JPEG encode
        # (the previous version leaked the file in that case).
        os.unlink(tmp_path)
    return [_label_score(item) for item in raw[:TOP_K]]
# -----------------------------------------------------------------------------
# Step 2: Calorie / nutrient estimation (LCEL chain)
# -----------------------------------------------------------------------------
def _fallback_nutrition(labels: list[dict[str, Any]], note: str) -> dict[str, Any]:
    """Build a zero-nutrient result in the output schema, keyed on the top label."""
    top = labels[0] if labels else None
    return {
        "food": top["label"] if top is not None else "unknown",
        "confidence": top["score"] if top is not None else 0.0,
        "calories_kcal": 0,
        "carbs_g": 0,
        "protein_g": 0,
        "fat_g": 0,
        "note": note,
    }


def estimate_calories(labels: list[dict[str, Any]]) -> dict[str, Any]:
    """Estimate per-serving calories and macros from the classifier's labels.

    Args:
        labels: Classifier output, a list of ``{"label": str, "score": float}``.

    Returns:
        The dict parsed from the LLM's JSON answer. If building or running the
        chain fails, or the parsed output is not a JSON object, a fallback dict
        with the same keys is returned instead: zero nutrients, the top label
        as ``food`` and the failure reason in ``note``.
    """
    try:
        # Chain construction is inside the try as well, so initialisation
        # failures (e.g. get_token() or endpoint setup raising) also reach
        # the fallback instead of escaping to the caller.
        chain = _chain_lazy()
        result = chain.invoke({"labels_json": json.dumps(labels, ensure_ascii=False)})
    except Exception as e:  # best-effort boundary: the UI always gets a result
        return _fallback_nutrition(
            labels, f"체인 실행 실패: {type(e).__name__}: {str(e)[:120]}"
        )
    if not isinstance(result, dict):
        # JsonOutputParser parses any JSON value (e.g. a list); callers expect an object.
        return _fallback_nutrition(
            labels, f"체인 실행 실패: JSON 객체가 아닌 출력 ({type(result).__name__})"
        )
    return result
# -----------------------------------------------------------------------------
# Step 3: Gradio callback
# -----------------------------------------------------------------------------
def analyze(image):
    """Gradio callback: classify *image*, then estimate its nutrition.

    Returns:
        A pair ``(label_view, nutrition)``. ``label_view`` maps each label to
        its score (for ``gr.Label``); ``nutrition`` is the dict shown by
        ``gr.JSON``. With no image, or when classification raises, the result
        is ``({}, {"error": <message>})`` instead of an exception.
    """
    if image is None:
        return {}, {"error": "이미지를 먼저 업로드해 주세요."}
    try:
        labels = classify_food(image)
    except Exception as e:  # network/API failures: report in the UI, mirroring estimate_calories
        return {}, {"error": f"이미지 분류 실패: {type(e).__name__}: {str(e)[:120]}"}
    label_view = {item["label"]: item["score"] for item in labels}
    return label_view, estimate_calories(labels)
# -----------------------------------------------------------------------------
# Step 4: UI
# -----------------------------------------------------------------------------
def build_ui() -> gr.Interface:
return gr.Interface(
fn=analyze,
inputs=gr.Image(type="pil", label="์Œ์‹ ์‚ฌ์ง„ ์—…๋กœ๋“œ"),
outputs=[
gr.Label(num_top_classes=TOP_K, label="์Œ์‹ ๋ถ„๋ฅ˜ ๊ฒฐ๊ณผ"),
gr.JSON(label="์นผ๋กœ๋ฆฌ & ์˜์–‘์†Œ ์ถ”์ •"),
],
title="๐Ÿฑ HuggingFace Calorie Counter (LangChain LCEL)",
description=(
"์Œ์‹ ์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•˜๋ฉด HF Inference API๋กœ ์Œ์‹์„ ์ธ์‹ํ•˜๊ณ , "
"LangChain LCEL ์ฒด์ธ์ด 1์ธ๋ถ„ ๊ธฐ์ค€ ์นผ๋กœ๋ฆฌ/์˜์–‘์†Œ๋ฅผ ์ถ”์ •ํ•ฉ๋‹ˆ๋‹ค. "
"๊ฒฐ๊ณผ๋Š” ์ฐธ๊ณ ์šฉ์ž…๋‹ˆ๋‹ค."
),
flagging_mode="never",
)
# Module-level `demo` for compatibility with the Space / HF runtime.
demo = build_ui()

if __name__ == "__main__":
    # HF Spaces set SPACE_ID and need a 0.0.0.0 bind; locally use 127.0.0.1.
    host = "0.0.0.0" if os.getenv("SPACE_ID") else "127.0.0.1"
    port = int(os.getenv("PORT", 7860))
    demo.launch(
        server_name=host,
        server_port=port,
        show_api=False,
        ssr_mode=False,
    )