Spaces:
Sleeping
Sleeping
| import ast | |
| import difflib | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import time | |
| from dotenv import load_dotenv | |
| from chart_generator import ChartGenerator | |
| from data_processor import DataProcessor | |
| load_dotenv() | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Model IDs (downloaded at Docker build, cached in HF_HOME) | |
| # --------------------------------------------------------------------------- | |
| QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-Coder-0.5B-Instruct") | |
| BART_MODEL_ID = os.getenv("BART_MODEL_ID", "ArchCoder/fine-tuned-bart-large") | |
| # --------------------------------------------------------------------------- | |
| # Prompt templates with few-shot examples | |
| # --------------------------------------------------------------------------- | |
| _SYSTEM_PROMPT = """\ | |
| You are a data visualization expert. Given the user request and dataset schema, \ | |
| output ONLY a valid JSON object. No explanation, no markdown fences, no extra text. | |
| Required JSON keys: | |
| "x" : string β exact column name for the x-axis | |
| "y" : array β one or more exact column names for the y-axis | |
| "chart_type" : string β one of: line, bar, scatter, pie, histogram, box, area | |
| "color" : string or null β optional CSS color like "red", "#4f8cff" | |
| Rules: | |
| - Use ONLY column names from the schema. Never invent names. | |
| - For pie charts: y must contain exactly one column. | |
| - For histogram/box: x may equal the first element of y. | |
| - Default to "line" if chart type is ambiguous. | |
| ### Examples | |
| Example 1: | |
| Schema: Year (integer), Sales (float), Profit (float) | |
| User: "plot sales over the years with a red line" | |
| Output: {"x": "Year", "y": ["Sales"], "chart_type": "line", "color": "red"} | |
| Example 2: | |
| Schema: Month (string), Revenue (float), Expenses (float) | |
| User: "bar chart comparing revenue and expenses by month" | |
| Output: {"x": "Month", "y": ["Revenue", "Expenses"], "chart_type": "bar", "color": null} | |
| Example 3: | |
| Schema: Category (string), Count (integer) | |
| User: "pie chart of count by category" | |
| Output: {"x": "Category", "y": ["Count"], "chart_type": "pie", "color": null} | |
| Example 4: | |
| Schema: Date (string), Temperature (float), Humidity (float) | |
| User: "scatter plot of temperature vs humidity in blue" | |
| Output: {"x": "Temperature", "y": ["Humidity"], "chart_type": "scatter", "color": "blue"} | |
| Example 5: | |
| Schema: Year (integer), Sales (float), Employee expense (float), Marketing expense (float) | |
| User: "show me an area chart of sales and marketing expense over years" | |
| Output: {"x": "Year", "y": ["Sales", "Marketing expense"], "chart_type": "area", "color": null} | |
| """ | |
| def _user_message(query: str, columns: list, dtypes: dict, sample_rows: list) -> str: | |
| schema = "\n".join(f" - {c} ({dtypes.get(c, 'unknown')})" for c in columns) | |
| samples = "".join(f" {json.dumps(r)}\n" for r in sample_rows[:3]) | |
| return ( | |
| f"Schema:\n{schema}\n\n" | |
| f"Sample rows:\n{samples}\n" | |
| f"User: \"{query}\"\n" | |
| f"Output:" | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Output parsing & validation | |
| # --------------------------------------------------------------------------- | |
| def _parse_output(text: str): | |
| text = text.strip() | |
| if "```" in text: | |
| for part in text.split("```"): | |
| part = part.strip().lstrip("json").strip() | |
| if part.startswith("{"): | |
| text = part | |
| break | |
| try: | |
| return json.loads(text) | |
| except json.JSONDecodeError: | |
| pass | |
| try: | |
| return ast.literal_eval(text) | |
| except (SyntaxError, ValueError): | |
| pass | |
| return None | |
| def _validate(args: dict, columns: list): | |
| if not isinstance(args, dict): | |
| return None | |
| if not all(k in args for k in ("x", "y", "chart_type")): | |
| return None | |
| if isinstance(args["y"], str): | |
| args["y"] = [args["y"]] | |
| valid = {"line", "bar", "scatter", "pie", "histogram", "box", "area"} | |
| if args["chart_type"] not in valid: | |
| args["chart_type"] = "line" | |
| if args["x"] not in columns: | |
| return None | |
| if not all(c in columns for c in args["y"]): | |
| return None | |
| return args | |
| def _pick_chart_type(query: str) -> str: | |
| lowered = query.lower() | |
| aliases = { | |
| "scatter": ["scatter", "scatterplot"], | |
| "bar": ["bar", "column"], | |
| "pie": ["pie", "donut"], | |
| "histogram": ["histogram", "distribution"], | |
| "box": ["box", "boxplot"], | |
| "area": ["area"], | |
| "line": ["line", "trend", "over time", "over the years"], | |
| } | |
| for chart_type, keywords in aliases.items(): | |
| if any(keyword in lowered for keyword in keywords): | |
| return chart_type | |
| return "line" | |
| def _pick_color(query: str): | |
| lowered = query.lower() | |
| colors = [ | |
| "red", "blue", "green", "yellow", "orange", "purple", "pink", | |
| "black", "white", "gray", "grey", "cyan", "teal", "indigo", | |
| ] | |
| for color in colors: | |
| if re.search(rf"\b{re.escape(color)}\b", lowered): | |
| return color | |
| return None | |
| def _pick_columns(query: str, columns: list, dtypes: dict): | |
| lowered = query.lower() | |
| query_tokens = re.findall(r"[a-zA-Z0-9_]+", lowered) | |
| def score_column(column: str) -> float: | |
| col_lower = column.lower() | |
| score = 0.0 | |
| if col_lower in lowered: | |
| score += 10.0 | |
| for token in query_tokens: | |
| if token and token in col_lower: | |
| score += 2.0 | |
| score += difflib.SequenceMatcher(None, lowered, col_lower).ratio() | |
| return score | |
| sorted_columns = sorted(columns, key=score_column, reverse=True) | |
| numeric_columns = [col for col in columns if dtypes.get(col) in {"integer", "float"}] | |
| temporal_columns = [col for col in columns if dtypes.get(col) == "datetime"] | |
| year_like = [col for col in columns if "year" in col.lower() or "date" in col.lower() or "month" in col.lower()] | |
| x_col = None | |
| for candidate in year_like + temporal_columns + sorted_columns: | |
| if candidate in columns: | |
| x_col = candidate | |
| break | |
| if x_col is None and columns: | |
| x_col = columns[0] | |
| y_candidates = [col for col in sorted_columns if col != x_col and col in numeric_columns] | |
| if not y_candidates: | |
| y_candidates = [col for col in numeric_columns if col != x_col] | |
| if not y_candidates: | |
| y_candidates = [col for col in columns if col != x_col] | |
| return x_col, y_candidates[:1] | |
def _heuristic_plot_args(query: str, columns: list, dtypes: dict) -> dict:
    """Keyword/fuzzy-matching fallback used when no LLM produced valid args.

    Always returns a complete plot-args dict with "x", "y", "chart_type"
    and "color" keys; "y" is guaranteed to be a non-empty list. (The old
    version had a dead isinstance(..., tuple) branch and could return an
    empty y for single-column or empty datasets.)
    """
    x_col, y_cols = _pick_columns(query, columns, dtypes)
    if not x_col:
        # No columns at all: keep the historical default axis name.
        x_col = "Year"
    if not y_cols:
        # Prefer any column other than x; as a last resort reuse x so the
        # returned args always name at least one y series.
        fallback = next((col for col in columns if col != x_col), x_col)
        y_cols = [fallback]
    if isinstance(y_cols, str):
        y_cols = [y_cols]
    return {
        "x": x_col,
        "y": y_cols,
        "chart_type": _pick_chart_type(query),
        "color": _pick_color(query),
    }
| # --------------------------------------------------------------------------- | |
| # Agent | |
| # --------------------------------------------------------------------------- | |
class LLM_Agent:
    """Turn a natural-language chart request into a rendered chart.

    Routes the request to one of several backends ("qwen", "gemini",
    "grok", "bart"), parses the model's JSON plot arguments, validates
    them against the loaded dataset, and falls back to keyword heuristics
    when the model output is unusable.
    """

    def __init__(self, data_path=None):
        # data_path is handed straight to DataProcessor; semantics of a
        # None path are defined there — presumably a default dataset.
        logger.info("Initializing LLM_Agent")
        self.data_processor = DataProcessor(data_path)
        self.chart_generator = ChartGenerator(self.data_processor.data)
        # Local HF models are loaded lazily on first use (see _run_qwen /
        # _run_bart) so start-up stays cheap when only the hosted APIs
        # (Gemini/Grok) are used.
        self._bart_tokenizer = None
        self._bart_model = None
        self._qwen_tokenizer = None
        self._qwen_model = None

    # -- model runners -------------------------------------------------------
    def _run_qwen(self, user_msg: str) -> str:
        """Qwen2.5-Coder-0.5B-Instruct — fast structured-JSON generation."""
        if self._qwen_model is None:
            # Lazy import + load: transformers is heavy and only needed here.
            from transformers import AutoModelForCausalLM, AutoTokenizer
            logger.info(f"Loading Qwen model: {QWEN_MODEL_ID}")
            self._qwen_tokenizer = AutoTokenizer.from_pretrained(QWEN_MODEL_ID)
            self._qwen_model = AutoModelForCausalLM.from_pretrained(QWEN_MODEL_ID)
            logger.info("Qwen model loaded.")
        messages = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": user_msg},
        ]
        text = self._qwen_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._qwen_tokenizer(text, return_tensors="pt")
        # Low temperature keeps the JSON output near-deterministic while
        # do_sample=True avoids the greedy-decode repetition failure mode.
        outputs = self._qwen_model.generate(
            **inputs, max_new_tokens=256, temperature=0.1, do_sample=True
        )
        # Decode only the newly generated tokens (skip the echoed prompt).
        return self._qwen_tokenizer.decode(
            outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True
        )

    def _run_gemini(self, user_msg: str) -> str:
        """Call the hosted Gemini API; raises ValueError when no API key is set."""
        import google.generativeai as genai
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("GEMINI_API_KEY is not set")
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(
            "gemini-2.0-flash",
            system_instruction=_SYSTEM_PROMPT,
        )
        return model.generate_content(user_msg).text

    def _run_grok(self, user_msg: str) -> str:
        """Call xAI's Grok through its OpenAI-compatible endpoint."""
        from openai import OpenAI
        api_key = os.getenv("GROK_API_KEY")
        if not api_key:
            raise ValueError("GROK_API_KEY is not set")
        client = OpenAI(api_key=api_key, base_url="https://api.x.ai/v1")
        resp = client.chat.completions.create(
            model="grok-3-mini",
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
            max_tokens=256,
            temperature=0.1,
        )
        return resp.choices[0].message.content

    def _run_bart(self, query: str) -> str:
        """ArchCoder/fine-tuned-bart-large — lightweight Seq2Seq fallback.

        NOTE: unlike the chat backends, this model receives only the raw
        query, not the schema-bearing system/user prompt.
        """
        if self._bart_model is None:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
            logger.info(f"Loading BART model: {BART_MODEL_ID}")
            self._bart_tokenizer = AutoTokenizer.from_pretrained(BART_MODEL_ID)
            self._bart_model = AutoModelForSeq2SeqLM.from_pretrained(BART_MODEL_ID)
            logger.info("BART model loaded.")
        inputs = self._bart_tokenizer(
            query, return_tensors="pt", max_length=512, truncation=True
        )
        outputs = self._bart_model.generate(**inputs, max_length=100)
        return self._bart_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # -- main entry point ----------------------------------------------------
    def process_request(self, data: dict) -> dict:
        """Process one chart request end-to-end.

        Expected keys in *data*: "query" (str), optional "file_path"
        (dataset to load before answering), optional "model" (backend
        name; defaults to "qwen").

        Returns a dict with "response", "chart_path", "chart_spec",
        "verified" and "plot_args"; "verified" is False only when chart
        generation itself failed — LLM failures are silently recovered
        via the heuristic fallback.
        """
        t0 = time.time()
        query = data.get("query", "")
        data_path = data.get("file_path")
        model = data.get("model", "qwen")
        # Reload dataset (and its chart generator) when a new file is given.
        if data_path and os.path.exists(data_path):
            self.data_processor = DataProcessor(data_path)
            self.chart_generator = ChartGenerator(self.data_processor.data)
        columns = self.data_processor.get_columns()
        dtypes = self.data_processor.get_dtypes()
        sample_rows = self.data_processor.preview(3)
        # Last-resort args if both the LLM and the heuristics fail validation.
        default_args = {
            "x": columns[0] if columns else "Year",
            "y": [columns[1]] if len(columns) > 1 else ["Sales"],
            "chart_type": "line",
        }
        raw_text = ""
        plot_args = None
        try:
            user_msg = _user_message(query, columns, dtypes, sample_rows)
            # Dispatch to the requested backend; unknown names use Qwen.
            if model == "gemini": raw_text = self._run_gemini(user_msg)
            elif model == "grok": raw_text = self._run_grok(user_msg)
            elif model == "bart": raw_text = self._run_bart(query)
            elif model == "qwen":
                try:
                    raw_text = self._run_qwen(user_msg)
                except Exception as qwen_exc:
                    # Qwen runs locally and may fail to load; BART is the
                    # smaller local fallback.
                    logger.warning(f"Qwen failed, falling back to BART: {qwen_exc}")
                    raw_text = self._run_bart(query)
            else:
                raw_text = self._run_qwen(user_msg)
            logger.info(f"LLM [{model}] output: {raw_text}")
            parsed = _parse_output(raw_text)
            plot_args = _validate(parsed, columns) if parsed else None
        except Exception as exc:
            # Any backend/parse failure drops through to the heuristic path.
            logger.error(f"LLM error [{model}]: {exc}")
            raw_text = str(exc)
        if not plot_args:
            logger.warning("Falling back to heuristic plot args")
            plot_args = _validate(_heuristic_plot_args(query, columns, dtypes), columns) or default_args
        try:
            chart_result = self.chart_generator.generate_chart(plot_args)
            chart_path = chart_result["chart_path"]
            chart_spec = chart_result["chart_spec"]
        except Exception as exc:
            logger.error(f"Chart generation error: {exc}")
            return {
                "response": f"Chart generation failed: {exc}",
                "chart_path": "",
                "chart_spec": None,
                "verified": False,
                "plot_args": plot_args,
            }
        logger.info(f"Request processed in {time.time() - t0:.2f}s")
        return {
            "response": json.dumps(plot_args),
            "chart_path": chart_path,
            "chart_spec": chart_spec,
            "verified": True,
            "plot_args": plot_args,
        }