| """ |
| 6์ฃผ์ฐจ ์ค์ต: HuggingFace Space ์นผ๋ก๋ฆฌ ์นด์ดํฐ (LangChain LCEL ยท ํ์ ๋ฒ์ ) |
| ===================================================================== |
| ์์ ์ฌ์ง์ ์
๋ก๋ํ๋ฉด |
| 1) HF Inference API์ ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ชจ๋ธ๋ก ์์์ ์ธ์ํ๊ณ |
| 2) ๊ทธ ๊ฒฐ๊ณผ๋ฅผ LangChain ChatHuggingFace LLM์ ๋๊ฒจ ์นผ๋ก๋ฆฌ/์์์๋ฅผ ์ถ์ ํ ๋ค |
| 3) Gradio UI๋ก ๋ณด์ฌ์ค๋ค. |
| |
| ํต์ฌ ๋ณ๊ฒฝ: estimate_calories๋ LCEL ์ฒด์ธ(prompt | llm | parser)์ผ๋ก ๊ตฌ์ฑํ๋ค. |
| ์ด ํ์ผ์ ๊ทธ๋๋ก HuggingFace Space(Gradio SDK)์ ์ฌ๋ฆฌ๋ฉด ๋ฐฐํฌ๋๋ค. |
| |
| TODO ๋ก ํ์๋ ๋ถ๋ถ์ ์ฑ์ ์์ฑํ ๋ค, |
| 1) ๋ก์ปฌ์์ ์คํํด๋ณด๊ณ |
| 2) HuggingFace Space์ ๋ฐฐํฌํ๋ค. |
| |
| ๋ก์ปฌ ์คํ: |
| uv run python app.py |
| """ |
|
|
| from __future__ import annotations |
|
|
import io
import json
import os
import tempfile
from typing import Any
|
|
| import gradio as gr |
| from gradio_client import utils as _gc_utils |
|
|
| |
| |
| |
# Workaround: some gradio_client versions raise when a JSON schema is a bare
# boolean (True/False are valid JSON Schemas, e.g. emitted by pydantic v2).
# Wrap get_type so boolean schemas short-circuit to "Any" before delegating.
_orig_get_type = _gc_utils.get_type


def _safe_get_type(schema):
    """Delegate to gradio_client's get_type, mapping boolean schemas to "Any"."""
    return "Any" if isinstance(schema, bool) else _orig_get_type(schema)


_gc_utils.get_type = _safe_get_type
|
|
# Same workaround for the private schema-to-type helper, which can also be
# handed a bare-boolean schema and would otherwise raise.
_orig_j2p = _gc_utils._json_schema_to_python_type


def _safe_j2p(schema, defs=None):
    """Delegate to _json_schema_to_python_type, mapping boolean schemas to "Any"."""
    return "Any" if isinstance(schema, bool) else _orig_j2p(schema, defs)


_gc_utils._json_schema_to_python_type = _safe_j2p
|
|
| from dotenv import load_dotenv |
| from huggingface_hub import InferenceClient |
| from langchain_core.output_parsers import JsonOutputParser |
| from langchain_core.prompts import ChatPromptTemplate |
| from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint |
| from PIL import Image |
|
|
| from model_config import LLM_MODEL, VISION_MODEL, get_token |
|
|
# Load variables from a local .env file (e.g. the HF token). On HF Spaces this
# is effectively a no-op because secrets arrive as real environment variables.
load_dotenv()


# Number of top-k classification labels to keep and display.
TOP_K = 3
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
# System prompt for the nutritionist LLM. The doubled braces ({{ }}) escape
# literal braces so ChatPromptTemplate does not treat the JSON schema example
# as template variables.
# NOTE(review): there appears to be a stray trailing '"' after the closing
# "}}" of the schema example — confirm whether it belongs in the prompt.
SYSTEM_PROMPT = """ 너는 한국 영양사 AI다.
사용자가 음식 분류 결과(top-k labels)를 주면,
가장 가능성 높은 음식 1개의 1인분 기준 영양정보를 추정해
반드시 다음 JSON 스키마만 출력하라. 다른 텍
스트/마크다운 금지.

{{"food": "음식명
", "confidence": 0.0~1.0,
"calories_kcal": 정수, "carbs_g": 정수,
"protein_g": 정수, "fat_g": 정수,
"note": "추정 근거 한 줄"}}"

"""
| |
| |
| |
# Lazily-initialized module-level singletons (built on first request, reused).
_vision_client: InferenceClient | None = None  # HF Inference API client
_chain = None  # LCEL chain: prompt | ChatHuggingFace | JsonOutputParser
|
|
|
|
def _vision_lazy() -> InferenceClient:
    """Return the shared InferenceClient, creating it on first use."""
    global _vision_client
    if _vision_client is not None:
        return _vision_client
    _vision_client = InferenceClient(token=get_token())
    return _vision_client
|
|
|
|
def _chain_lazy():
    """Build (once) and return the LCEL chain: prompt | ChatHuggingFace | JsonOutputParser."""
    global _chain
    if _chain is not None:
        return _chain

    # Raw text-generation endpoint on the HF Inference API.
    endpoint = HuggingFaceEndpoint(
        repo_id=LLM_MODEL,
        task="text-generation",
        max_new_tokens=300,
        temperature=0.2,
        huggingfacehub_api_token=get_token(),
    )
    chat_model = ChatHuggingFace(llm=endpoint)

    # System message fixes the JSON-only contract; human message carries the
    # classifier's top-k labels as serialized JSON.
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", SYSTEM_PROMPT),
            ("human", "다음은 이미지 분류기의 top-k 결과다:\n{labels_json}"),
        ]
    )

    _chain = prompt | chat_model | JsonOutputParser()
    return _chain
|
|
|
|
| |
| |
| |
def classify_food(image: Image.Image) -> list[dict[str, Any]]:
    """Run the HF image-classification model on *image* and return top-k results.

    Args:
        image: PIL image uploaded by the user.

    Returns:
        Up to TOP_K entries of ``{"label": str, "score": float}``, in the
        order the Inference API returned them.
    """
    client = _vision_lazy()

    # Encode to JPEG entirely in memory: InferenceClient accepts raw bytes,
    # so the previous NamedTemporaryFile(delete=False) + unlink round-trip
    # (which also breaks on Windows, where a still-open temp file cannot be
    # reopened by path) is unnecessary.
    buffer = io.BytesIO()
    image.convert("RGB").save(buffer, format="JPEG")
    raw = client.image_classification(buffer.getvalue(), model=VISION_MODEL)

    # Depending on the huggingface_hub version, items are plain dicts or
    # dataclass-like objects with .label/.score attributes — handle both.
    results: list[dict[str, Any]] = []
    for item in raw[:TOP_K]:
        if isinstance(item, dict):
            results.append({"label": item["label"], "score": float(item["score"])})
        else:
            results.append({"label": item.label, "score": float(item.score)})
    return results
|
|
|
|
| |
| |
| |
def estimate_calories(labels: list[dict[str, Any]]) -> dict[str, Any]:
    """Estimate per-serving calories/macros for the top label via the LCEL chain.

    Args:
        labels: Classifier output, each entry ``{"label": str, "score": float}``.

    Returns:
        The chain's parsed JSON dict on success; otherwise a zeroed fallback
        dict (same schema) whose ``note`` records the failure, so the UI
        always receives valid JSON.
    """
    # Bug fix: the chain build and serialization were duplicated, with one
    # copy OUTSIDE the try — so a failure in _chain_lazy() (bad token,
    # missing model, import error) escaped the fallback handler entirely.
    # Everything that can fail now lives inside the try.
    try:
        chain = _chain_lazy()
        labels_json = json.dumps(labels, ensure_ascii=False)
        return chain.invoke({"labels_json": labels_json})
    except Exception as e:  # deliberate best-effort boundary for the UI
        return {
            "food": labels[0]["label"] if labels else "unknown",
            "confidence": labels[0]["score"] if labels else 0.0,
            "calories_kcal": 0,
            "carbs_g": 0,
            "protein_g": 0,
            "fat_g": 0,
            "note": f"체인 실행 실패: {type(e).__name__}: {str(e)[:120]}",
        }
|
|
|
|
| |
| |
| |
def analyze(image):
    """Gradio handler: classify the uploaded photo, then estimate nutrition.

    Returns a (label->score dict for gr.Label, nutrition dict for gr.JSON) pair.
    """
    if image is None:
        return {}, {"error": "이미지를 먼저 업로드해 주세요."}

    predictions = classify_food(image)
    score_by_label = {p["label"]: p["score"] for p in predictions}
    return score_by_label, estimate_calories(predictions)
|
|
|
|
| |
| |
| |
def build_ui() -> gr.Interface:
    """Assemble the Gradio Interface: image in -> (label scores, nutrition JSON) out."""
    outputs = [
        gr.Label(num_top_classes=TOP_K, label="음식 분류 결과"),
        gr.JSON(label="칼로리 & 영양소 추정"),
    ]
    description = (
        "음식 사진을 업로드하면 HF Inference API로 음식을 인식하고, "
        "LangChain LCEL 체인이 1인분 기준 칼로리/영양소를 추정합니다. "
        "결과는 참고용입니다."
    )
    return gr.Interface(
        fn=analyze,
        inputs=gr.Image(type="pil", label="음식 사진 업로드"),
        outputs=outputs,
        title="🍱 HuggingFace Calorie Counter (LangChain LCEL)",
        description=description,
        flagging_mode="never",
    )
|
|
|
|
| |
# Built at import time so the `gradio` CLI and HF Spaces can discover `demo`.
demo = build_ui()


if __name__ == "__main__":
    # On a Space (SPACE_ID is set) bind to all interfaces; locally stay on
    # loopback. PORT may be injected by the hosting environment.
    is_space = bool(os.getenv("SPACE_ID"))
    demo.launch(
        server_name="0.0.0.0" if is_space else "127.0.0.1",
        server_port=int(os.getenv("PORT", 7860)),
        show_api=False,
        ssr_mode=False,
    )
|
|