Spaces:
Running
Running
| """Hugging Face Gradio Space: Command A+ multimodal chat demo.""" | |
| from __future__ import annotations | |
| import base64 | |
| import logging | |
| import mimetypes | |
| import os | |
| import re | |
| from collections.abc import Iterator | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| from cohere import ClientV2 | |
| from cohere.core.api_error import ApiError | |
# Directory containing this file; bundled assets (example images) resolve against it.
APP_ROOT = Path(__file__).resolve().parent
logger = logging.getLogger(__name__)
# --- Branding, model identity and links shown in the UI ---------------------
APP_TITLE = "Command A+"
# Passed to the Cohere SDK as ``client_name`` when the client is built.
CLIENT_NAME = "hf-command-a-plus-05-2026"
DEFAULT_MODEL_ID = "command-a-plus-05-2026"
DEFAULT_TEMPERATURE = 0.2
MODEL_URL = "https://huggingface.co/CohereLabs/command-a-plus-05-2026-w4a4"
PRIVACY_URL = "https://cohere.com/privacy"
# --- Image handling -----------------------------------------------------------
# ``detail`` value attached to every image_url content block sent to Cohere.
IMAGE_DETAIL = "auto"
# Per-request caps enforced by _ImageBudget; one budget covers the replayed
# history and the current turn together.
MAX_IMAGES_PER_REQUEST = 20
MAX_TOTAL_IMAGE_BYTES = 20 * 1024 * 1024
# Human-readable form of the byte cap (e.g. "20 MB") for UI/error text.
MAX_TOTAL_IMAGE_LABEL = f"{MAX_TOTAL_IMAGE_BYTES // (1024 * 1024)} MB"
# MIME types accepted for local image attachments.
IMAGE_MIME_TYPES = {"image/gif", "image/jpeg", "image/png", "image/webp"}
# Matches a complete <think>...</think> span (case-insensitive, tolerant of
# whitespace inside the tags, spanning newlines) so earlier reasoning can be
# stripped from chat history before it is resent.
THINKING_BLOCK_RE = re.compile(r"<\s*think\s*>.*?<\s*/\s*think\s*>", re.IGNORECASE | re.DOTALL)
# Sample invoice image used by the starter examples.
INVOICE_IMAGE = str(APP_ROOT / "img" / "invoice-1.jpg")
# Environment configuration; a blank model id falls back to the default.
MODEL_ID = os.getenv("COMMAND_A_PLUS_MODEL_ID", DEFAULT_MODEL_ID).strip() or DEFAULT_MODEL_ID
# Empty string when unset, which disables inference (see _build_client).
API_KEY = os.getenv("COHERE_API_KEY", "").strip()
# Visual theme: Gradio's Soft theme with explicit light/dark palette overrides.
# It is applied at launch time via ``demo.launch(theme=APP_THEME, ...)``.
APP_THEME = gr.themes.Soft(
    primary_hue="stone",
    secondary_hue="green",
    neutral_hue="stone",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
).set(
    # Page background and body text.
    body_background_fill="#ffffff",
    body_background_fill_dark="#07110f",
    body_text_color="#212121",
    body_text_color_dark="#f7f5ef",
    body_text_color_subdued="#75758a",
    body_text_color_subdued_dark="#b9b8ad",
    # Component containers.
    block_background_fill="#ffffff",
    block_background_fill_dark="#0d1714",
    block_border_color="#d9d9dd",
    block_border_color_dark="rgba(238, 236, 231, 0.22)",
    block_label_text_color="#17171c",
    block_label_text_color_dark="#f7f5ef",
    # Inputs.
    input_background_fill="#ffffff",
    input_background_fill_dark="#07110f",
    input_border_color="#d9d9dd",
    input_border_color_dark="rgba(238, 236, 231, 0.28)",
    # Primary buttons (inverted between light and dark modes).
    button_primary_background_fill="#17171c",
    button_primary_background_fill_dark="#f7f5ef",
    button_primary_background_fill_hover="#003c33",
    button_primary_background_fill_hover_dark="#edfce9",
    button_primary_text_color="#ffffff",
    button_primary_text_color_dark="#07110f",
    # Links.
    link_text_color="#003c33",
    link_text_color_dark="#7fd3b0",
)
def _build_client() -> ClientV2 | None:
    """Create the Cohere V2 client, or return None when no API key is configured.

    A missing key is logged as a warning rather than raised, so the UI can still
    load and show a configuration banner.
    """
    if not API_KEY:
        logger.warning("COHERE_API_KEY is not set; inference is disabled until configured.")
        return None
    return ClientV2(api_key=API_KEY, client_name=CLIENT_NAME)
# Shared Cohere client built once at import; None means inference is disabled
# (respond() then returns a configuration hint instead of calling the API).
CLIENT = _build_client()
| def _extract_content_parts(content: object) -> tuple[str, str]: | |
| """Extract visible text and thinking text from Cohere content shapes.""" | |
| if content is None: | |
| return "", "" | |
| if isinstance(content, str): | |
| return content, "" | |
| if isinstance(content, list): | |
| parts = [_extract_content_parts(block) for block in content] | |
| return "".join(text for text, _ in parts), "".join(thinking for _, thinking in parts) | |
| if isinstance(content, dict): | |
| text = str(content.get("text") or "") | |
| thinking = str(content.get("thinking") or "") | |
| if not text and not thinking and "content" in content: | |
| return _extract_content_parts(content.get("content")) | |
| return text, thinking | |
| text = getattr(content, "text", None) | |
| thinking = getattr(content, "thinking", None) | |
| return (str(text) if text is not None else ""), (str(thinking) if thinking is not None else "") | |
def _extract_text(content: object) -> str:
    """Return only the visible-text part of ``content``; thinking text is discarded."""
    visible, _thinking = _extract_content_parts(content)
    return visible
def _strip_thinking_blocks(text: str) -> str:
    """Delete every complete ``<think>...</think>`` span, then trim whitespace."""
    without_reasoning = THINKING_BLOCK_RE.sub("", text)
    return without_reasoning.strip()
| def _format_response(output: str, thinking: str) -> str: | |
| thinking = thinking.strip() | |
| if not thinking: | |
| return output | |
| if not output: | |
| return f"<think>{thinking}</think>" | |
| return f"<think>{thinking}</think>\n\n{output}" | |
| def _file_path_or_url(file_value: object) -> str | None: | |
| if isinstance(file_value, str): | |
| return file_value | |
| if isinstance(file_value, dict): | |
| raw_value = file_value.get("path") or file_value.get("name") or file_value.get("url") | |
| return str(raw_value) if raw_value else None | |
| path = getattr(file_value, "path", None) | |
| return str(path) if path else None | |
| def _guess_mime_type(path_or_url: str, file_value: object) -> str: | |
| guess_from = path_or_url | |
| if isinstance(file_value, dict): | |
| guess_from = str( | |
| file_value.get("orig_name") or file_value.get("name") or path_or_url | |
| ) | |
| return mimetypes.guess_type(guess_from)[0] or "image/png" | |
| def _data_url_decoded_size(url: str) -> int: | |
| """Best-effort size estimate for a `data:` URL payload (base64 or percent-encoded).""" | |
| _, _, payload = url.partition(",") | |
| if not payload: | |
| return 0 | |
| head = url.split(",", 1)[0] | |
| if ";base64" in head: | |
| padding = payload.count("=") | |
| return max(0, (len(payload) * 3) // 4 - padding) | |
| return len(payload) | |
| def _text_block(text: str) -> dict[str, Any]: | |
| return {"type": "text", "text": text} | |
| def _message_files(message: dict[str, Any]) -> list[object]: | |
| files = message.get("files") or [] | |
| return files if isinstance(files, list) else [files] | |
| class _ImageBudget: | |
| """Enforce the Cohere API per-request image count and total-byte limits.""" | |
| def __init__(self) -> None: | |
| self.count = 0 | |
| self.bytes = 0 | |
| def add(self, size: int) -> None: | |
| self.count += 1 | |
| if self.count > MAX_IMAGES_PER_REQUEST: | |
| raise gr.Error( | |
| f"This conversation exceeds the {MAX_IMAGES_PER_REQUEST}-image limit per request. " | |
| "Start a new chat or remove some images." | |
| ) | |
| self.bytes += max(0, size) | |
| if self.bytes > MAX_TOTAL_IMAGE_BYTES: | |
| raise gr.Error( | |
| f"Total image data exceeds {MAX_TOTAL_IMAGE_LABEL} per request. " | |
| "Use smaller images or fewer attachments." | |
| ) | |
def _image_block_from_file(
    file_value: object,
    budget: _ImageBudget,
    *,
    required: bool,
) -> dict[str, Any] | None:
    """Convert a Gradio file value into a Cohere ``image_url`` content block.

    Handling by source:
    - Remote ``http(s)`` URLs are passed through. Their size is unknown here,
      so they count toward the image cap only.
    - ``data:`` URLs have their declared media type checked against
      ``IMAGE_MIME_TYPES`` (the same allow-list as local files), their payload
      size charged to the budget, and are then passed through.
    - Local files are type-checked by filename, charged by on-disk size, and
      inlined as a base64 ``data:`` URL.

    Args:
        file_value: A path/URL string, a dict with ``path``/``name``/``url``,
            or an object with a ``path`` attribute.
        budget: Per-request image budget; every accepted image is charged to it.
        required: True for the current user turn: unreadable or unsupported
            files raise ``gr.Error``. False for history replay: such files
            are logged and skipped, so one bad earlier attachment cannot make
            every later turn of the conversation fail.

    Returns:
        The ``image_url`` content block, or ``None`` if the file was skipped.

    Raises:
        gr.Error: For an unreadable or unsupported file when ``required`` is
            True, or when ``budget`` is exceeded.
        OSError: If reading a local file fails after the existence check.
    """
    unreadable_msg = "The uploaded image could not be read. Try uploading again."
    unsupported_msg = "Unsupported attachment. Use PNG, JPEG, WEBP, or non-animated GIF."

    def reject(reason: str) -> None:
        # Surface the problem for the current turn; log and drop it for history.
        if required:
            raise gr.Error(reason)
        logger.warning("Skipping attachment from history: %s", reason)

    def url_block(url: str) -> dict[str, Any]:
        return {"type": "image_url", "image_url": {"url": url, "detail": IMAGE_DETAIL}}

    path_or_url = _file_path_or_url(file_value)
    if not path_or_url:
        reject(unreadable_msg)
        return None

    if path_or_url.startswith(("http://", "https://")):
        # Remote URLs: size is unknown client-side; count toward image cap only.
        budget.add(0)
        return url_block(path_or_url)

    if path_or_url.startswith("data:"):
        # Declared media type: text between "data:" and the first ";" or ",".
        media_type = path_or_url[len("data:"):].split(",", 1)[0].split(";", 1)[0].strip().lower()
        if media_type not in IMAGE_MIME_TYPES:
            reject(unsupported_msg)
            return None
        budget.add(_data_url_decoded_size(path_or_url))
        return url_block(path_or_url)

    path = Path(path_or_url)
    if not path.is_file():
        reject(unreadable_msg)
        return None

    mime_type = _guess_mime_type(path_or_url, file_value)
    if mime_type not in IMAGE_MIME_TYPES:
        reject(unsupported_msg)
        return None

    # Charge the on-disk size before reading, so an over-budget file is
    # rejected without being loaded into memory.
    budget.add(path.stat().st_size)
    encoded = base64.standard_b64encode(path.read_bytes()).decode("ascii")
    return url_block(f"data:{mime_type};base64,{encoded}")
| def _blocks_from_user_message( | |
| message: dict[str, Any] | None, | |
| budget: _ImageBudget, | |
| *, | |
| required_files: bool, | |
| ) -> list[dict[str, Any]]: | |
| if not message: | |
| return [] | |
| blocks: list[dict[str, Any]] = [] | |
| text = str(message.get("text") or "").strip() | |
| if text: | |
| blocks.append(_text_block(text)) | |
| files = _message_files(message) | |
| for file_value in files: | |
| image_block = _image_block_from_file(file_value, budget, required=required_files) | |
| if image_block: | |
| blocks.append(image_block) | |
| if not text and files: | |
| blocks.insert(0, _text_block("Please analyze the attached image(s).")) | |
| return blocks | |
| def _blocks_from_history_content(content: object, budget: _ImageBudget) -> list[dict[str, Any]]: | |
| if isinstance(content, str): | |
| text = _strip_thinking_blocks(content) | |
| return [_text_block(text)] if text else [] | |
| if isinstance(content, list): | |
| blocks: list[dict[str, Any]] = [] | |
| for item in content: | |
| blocks.extend(_blocks_from_history_content(item, budget)) | |
| return blocks | |
| if isinstance(content, dict): | |
| if content.get("path") or content.get("name") or content.get("url"): | |
| image_block = _image_block_from_file(content, budget, required=False) | |
| return [image_block] if image_block else [] | |
| text = _strip_thinking_blocks(_extract_text(content)) | |
| return [_text_block(text)] if text else [] | |
| text = _strip_thinking_blocks(_extract_text(content)) | |
| return [_text_block(text)] if text else [] | |
| def _cohere_content_from_blocks(blocks: list[dict[str, Any]]) -> str | list[dict[str, Any]]: | |
| if len(blocks) == 1 and blocks[0].get("type") == "text": | |
| return str(blocks[0].get("text") or "") | |
| return blocks | |
| def _assistant_text_from_blocks(blocks: list[dict[str, Any]]) -> str: | |
| return "".join( | |
| str(block.get("text") or "") | |
| for block in blocks | |
| if block.get("type") == "text" | |
| ).strip() | |
| def _append_history_messages( | |
| messages: list[dict[str, Any]], | |
| history: list[dict[str, Any]] | None, | |
| budget: _ImageBudget, | |
| ) -> None: | |
| for item in history or []: | |
| role = item.get("role") if isinstance(item, dict) else None | |
| if role not in {"assistant", "user"}: | |
| continue | |
| blocks = _blocks_from_history_content(item.get("content"), budget) | |
| if not blocks: | |
| continue | |
| if role == "assistant": | |
| text = _assistant_text_from_blocks(blocks) | |
| if text: | |
| messages.append({"role": "assistant", "content": text}) | |
| else: | |
| messages.append({"role": "user", "content": _cohere_content_from_blocks(blocks)}) | |
| def _no_output_note(finish_reason: str) -> str: | |
| """Friendly message when the stream ended without emitting any visible text.""" | |
| if finish_reason == "MAX_TOKENS": | |
| return ( | |
| "_The model hit its native output-token cap before producing a final " | |
| "answer (generated reasoning only). Try a shorter or simpler prompt._" | |
| ) | |
| if finish_reason == "ERROR": | |
| return "_The model returned an error before producing an answer. Please try again._" | |
| if finish_reason == "STOP_SEQUENCE": | |
| return "_The model stopped at a stop sequence before producing visible text._" | |
| return ( | |
| f"_The model finished without producing a visible response " | |
| f"(finish_reason={finish_reason}). Please try again or rephrase._" | |
| ) | |
| def _format_api_error(exc: ApiError) -> str: | |
| """Turn a Cohere ApiError into a short, user-readable diagnostic.""" | |
| body = exc.body | |
| if isinstance(body, dict): | |
| message = body.get("message") or body.get("error") or "" | |
| body_text = str(message) if message else str(body) | |
| else: | |
| body_text = str(body or "").strip() | |
| if exc.status_code == 404 and "page not found" in body_text.lower(): | |
| return ( | |
| f"Model `{MODEL_ID}` was not found on the Cohere API. " | |
| "Check the model id or set the `COMMAND_A_PLUS_MODEL_ID` env var." | |
| ) | |
| if exc.status_code in (401, 403): | |
| return "Your `COHERE_API_KEY` was rejected. Check the secret in Space settings." | |
| if exc.status_code == 429: | |
| return "Rate-limited by the Cohere API. Please wait and try again." | |
| return body_text[:240] or f"HTTP {exc.status_code}" | |
def respond(
    message: dict[str, Any] | None,
    history: list[dict[str, Any]],
) -> Iterator[str]:
    """Stream assistant text for a multimodal chat turn (``gr.ChatInterface`` callback).

    Processing steps:
    1. Prior turns are replayed from ``history``: images are re-encoded and
       earlier ``<think>`` blocks are stripped.
    2. The current ``message`` is appended.
    3. The Cohere chat stream is consumed with thinking enabled.

    Args:
        message: Multimodal textbox value (``text`` / ``files`` keys) or None.
        history: Gradio chat history in messages format.

    Yields:
        The cumulative response on every update, with reasoning wrapped in
        ``<think>`` tags by ``_format_response``. A single status or diagnostic
        message is yielded instead when nothing can be sent or no answer arrives.

    Raises:
        gr.Error: When an image cannot be read, an attachment is unsupported,
            or the per-request image budget is exceeded.
    """
    if CLIENT is None:
        yield (
            "This Space needs a `COHERE_API_KEY` secret to call the Cohere API. "
            "Add it in Space settings, then refresh the page."
        )
        return
    client = CLIENT
    messages: list[dict[str, Any]] = []
    budget = _ImageBudget()
    try:
        # Replaying history also reads earlier image files from disk, so guard it
        # the same way as the current turn instead of letting a raw OSError escape.
        _append_history_messages(messages, history, budget)
        current_blocks = _blocks_from_user_message(message, budget, required_files=True)
    except OSError as exc:
        logger.exception("Failed to read image")
        raise gr.Error("Could not read the image file.") from exc
    if not current_blocks:
        yield "Send a message or attach an image to start the conversation."
        return
    messages.append({"role": "user", "content": _cohere_content_from_blocks(current_blocks)})
    output = ""
    thinking_output = ""
    finish_reason: object = None
    event_counts: dict[str, int] = {}
    try:
        stream = client.chat_stream(
            model=MODEL_ID,
            messages=messages,
            temperature=DEFAULT_TEMPERATURE,
            thinking={"type": "enabled"},
        )
        for event in stream:
            event_type = getattr(event, "type", None) or "unknown"
            event_counts[event_type] = event_counts.get(event_type, 0) + 1
            delta = getattr(event, "delta", None)
            if event_type in ("content-delta", "content-start"):
                msg = getattr(delta, "message", None) if delta is not None else None
                if msg is None:
                    continue
                text, thinking = _extract_content_parts(getattr(msg, "content", None))
                if thinking:
                    thinking_output += thinking
                    yield _format_response(output, thinking_output)
                if text:
                    output += text
                    yield _format_response(output, thinking_output)
            elif event_type == "message-end":
                # delta carries finish_reason and (sometimes) usage info.
                finish_reason = getattr(delta, "finish_reason", None)
                if finish_reason is None and isinstance(delta, dict):
                    finish_reason = delta.get("finish_reason")
        logger.info(
            "Cohere stream ended: finish_reason=%s, output_len=%d, thinking_len=%d, events=%s",
            finish_reason, len(output), len(thinking_output), event_counts,
        )
        if not output:
            # str() guards against the SDK reporting a non-str reason value.
            reason_text = str(finish_reason or "unknown").upper()
            logger.warning(
                "Stream produced no visible text. finish_reason=%s, thinking_len=%d, events=%s",
                reason_text, len(thinking_output), event_counts,
            )
            note = _no_output_note(reason_text)
            yield _format_response(note, thinking_output)
    except ApiError as exc:
        logger.exception("Cohere API error (status=%s)", exc.status_code)
        detail = _format_api_error(exc)
        gr.Warning(f"Cohere API error ({exc.status_code}). {detail}")
        yield _format_response(output + f"\n\n_Cohere API error_: {detail}", thinking_output)
    except Exception as exc:
        # Top-level boundary of the chat callback: report rather than crash the UI.
        logger.exception("Unexpected error calling Cohere API")
        gr.Warning(f"Unexpected error: {exc}")
        yield _format_response(output + f"\n\n_Unexpected error_: {exc}", thinking_output)
| def _example_message(text: str, files: list[str] | None = None) -> dict[str, Any]: | |
| return {"text": text, "files": files or []} | |
def build_examples() -> tuple[list[dict[str, Any]], list[str]]:
    """Return ``(examples, labels)`` for the chat's starter prompts.

    Two prompts attach the bundled invoice image; the rest are text-only
    reasoning and multilingual prompts. ``labels[i]`` names ``examples[i]``.
    """
    catalogue = [
        (
            "Invoice: totals",
            _example_message(
                "What is the total amount of the invoice with and without tax?",
                files=[INVOICE_IMAGE],
            ),
        ),
        (
            "Invoice: line items",
            _example_message(
                "Extract every line item from this invoice as a JSON array with "
                "description, quantity, unit price, and amount.",
                files=[INVOICE_IMAGE],
            ),
        ),
        (
            "Symbol reasoning",
            _example_message(
                "```\nX +\n *\n```\n\n"
                "Reason about the above scene depicted in the markdown code block. "
                "If I interchange the locations of * and X, and then I interchange the "
                "locations of * and +, and then I flip the image like a left-right mirror, "
                "which symbol is on the leftmost part of the image?"
            ),
        ),
        (
            "Overtaking puzzle",
            _example_message(
                "You are running a race and overtake the person at position 76487423. "
                "What place are you in now?"
            ),
        ),
        (
            "Socks in the dark",
            _example_message(
                "Twenty-four red socks and 24 blue socks are lying in a drawer in a dark "
                "room. What is the minimum number of socks I must take out of the drawer "
                "which will guarantee that I have at least two socks of the same color?"
            ),
        ),
        (
            "Relativité en français",
            _example_message("Explique la théorie de la relativité en français."),
        ),
    ]
    examples = [prompt for _label, prompt in catalogue]
    labels = [label for label, _prompt in catalogue]
    return examples, labels
# Computed once at import time; handed to gr.ChatInterface in build_demo().
EXAMPLE_ROWS, EXAMPLE_LABELS = build_examples()
def build_hero_markdown() -> str:
    """Return the header HTML: title, model link, image limits, and privacy notice.

    It is rendered with ``gr.Markdown(..., sanitize_html=False)`` in
    ``build_demo``. The image-count and total-size caps are enforced together
    (both are checked by ``_ImageBudget``), so the note joins them with "and"
    rather than implying either one alone applies.
    """
    return f"""
<section class="hero">
<div class="hero-grid">
<div>
<h1>{APP_TITLE}</h1>
</div>
</div>
<p class="compact-note">Model: <a href="{MODEL_URL}" target="_blank" rel="noopener noreferrer"><code>{MODEL_ID}</code></a> · Up to <code>{MAX_IMAGES_PER_REQUEST}</code> images and <code>{MAX_TOTAL_IMAGE_LABEL}</code> total per request (PNG, JPEG, WEBP, non-animated GIF) · By using this Space you agree to the
<a href="{PRIVACY_URL}" target="_blank" rel="noopener noreferrer">Cohere Privacy Policy</a>. Images are sent to the Cohere API for processing.</p>
</section>
"""
def build_placeholder_html() -> str:
    """Return the HTML shown inside the empty chatbot (its ``placeholder``) before the first message."""
    return f"""
<div class="chat-placeholder">
<div class="placeholder-kicker">{APP_TITLE}</div>
<strong>Ask about anything.</strong>
<span>Drop a document, chart, or photo and start the conversation.</span>
</div>
"""
def build_configuration_banner() -> str:
    """Return the HTML warning shown when no Cohere API key is configured."""
    fragments = (
        '<div class="status-banner"><strong>Configuration required.</strong> ',
        "Set the <code>COHERE_API_KEY</code> secret in Space settings to enable generation.</div>",
    )
    return "".join(fragments)
def build_demo() -> gr.Blocks:
    """Build the Gradio Blocks app: header, optional configuration banner, and chat UI.

    The chat is a ``gr.ChatInterface`` driven by ``respond``. It is rendered
    with the pre-configured ``gr.Chatbot`` and ``gr.MultimodalTextbox`` created
    here.

    Returns:
        The ``gr.Blocks`` instance (launched separately at the bottom of the file).
    """
    with gr.Blocks(title=APP_TITLE, fill_height=True) as demo:
        with gr.Column(elem_classes="app-shell"):
            # The header markup is generated locally from module constants,
            # so HTML sanitization is turned off to keep the links/classes.
            gr.Markdown(build_hero_markdown(), sanitize_html=False)
            if CLIENT is None:
                # Tell the operator inference is disabled until the key is set.
                gr.Markdown(build_configuration_banner(), sanitize_html=False)
            chatbot = gr.Chatbot(
                show_label=False,
                layout="bubble",
                min_height=520,
                height="62vh",
                placeholder=build_placeholder_html(),
                # respond() wraps reasoning in <think>...</think> via _format_response.
                reasoning_tags=[("<think>", "</think>")],
                elem_classes=["command-chatbot"],
                latex_delimiters=[
                    {"left": "$$", "right": "$$", "display": True},
                    {"left": "\\[", "right": "\\]", "display": True},
                    {"left": "\\(", "right": "\\)", "display": False},
                ],
            )
            textbox = gr.MultimodalTextbox(
                # Any image upload is allowed here; _image_block_from_file then
                # restricts to the formats in IMAGE_MIME_TYPES.
                file_types=["image"],
                file_count="multiple",
                sources=["upload"],
                placeholder="Message Command A+ or attach images...",
                lines=1,
                max_lines=6,
                show_label=False,
                container=False,
                submit_btn=True,
                stop_btn=True,
                elem_classes=["command-input"],
            )
            gr.ChatInterface(
                fn=respond,
                multimodal=True,
                chatbot=chatbot,
                textbox=textbox,
                examples=EXAMPLE_ROWS,
                example_labels=EXAMPLE_LABELS,
                run_examples_on_click=True,
                cache_examples=False,
                # Gradio temp-file cleanup as (frequency, age) in seconds — per
                # Gradio docs; uploaded history images may disappear after this.
                delete_cache=(1800, 1800),
                save_history=True,
                stop_btn=True,
                fill_width=True,
                show_progress="minimal",
            )
    return demo
# The app is built at import time, so the module exposes a ready ``demo`` object.
demo = build_demo()
# Enable request queuing; default_concurrency_limit=2 is Gradio's per-event cap
# on how many runs execute at once.
demo.queue(default_concurrency_limit=2)
if __name__ == "__main__":
    # Theme and stylesheet are supplied at launch time rather than on gr.Blocks.
    demo.launch(theme=APP_THEME, css_paths="style.css")