| |
| |
| |
|
|
| import os |
| import re |
| import json |
| from datetime import datetime |
| from typing import List, Dict, Any, Optional, Literal |
|
|
| from fastapi import FastAPI, Request, BackgroundTasks |
| from fastapi.middleware.cors import CORSMiddleware |
| import gradio as gr |
| import uvicorn |
| from pydantic import BaseModel |
| from fastapi.responses import RedirectResponse |
| from huggingface_hub.inference._mcp.agent import Agent |
| from dotenv import load_dotenv |
| load_dotenv() |
|
|
# --- Configuration, read once from the environment at import time ---
# Hub API token; get_agent() only builds an Agent when this is set.
HF_TOKEN = os.getenv("HF_TOKEN")
# Shared secret compared against the header sent with each webhook request.
WEBHOOK_SECRET = os.getenv("WEBHOOK_SECRET")
# Model handed to the Agent; falls back to a default when HF_MODEL is unset.
HF_MODEL = os.getenv("HF_MODEL", "HuggingFaceH4/zephyr-7b-beta")
# Provider handed to the Agent.
DEFAULT_PROVIDER: Literal['hf-inference'] = "hf-inference"
# NOTE(review): read here but not referenced anywhere else in the visible code.
HF_PROVIDER = os.getenv("HF_PROVIDER")

# --- Process-wide mutable state ---
# Lazily created Agent singleton (see get_agent()).
agent_instance: Optional[Agent] = None
# In-memory log of processed interactions, appended to by process_webhook_comment().
tag_operations_store: List[Dict[str, Any]] = []
|
|
# Tags the bot picks up when they appear anywhere in a comment or title.
# Matching is a lowercase substring test, so every entry must be a non-blank,
# lowercase token: a blank or whitespace-only entry would match any text that
# contains a space and surface as an empty tag.
# NOTE(review): a whitespace-only entry used to sit between "object-detection"
# and "fill-mask"; if it stood for a real tag, re-add that tag by name.
RECOGNIZED_TAGS = {
    "pytorch",
    "tensorflow",
    "jax",
    "transformers",
    "diffusers",
    "text-generation",
    "text-classification",
    "question-answering",
    "text-to-image",
    "image-classification",
    "object-detection",
    "fill-mask",
    "token-classification",
    "translation",
    "summarization",
    "feature-extraction",
    "sentence-similarity",
    "zero-shot-classification",
    "image-to-text",
    "automatic-speech-recognition",
    "audio-classification",
    "voice-activity-detection",
    "depth-estimation",
    "image-segmentation",
    "video-classification",
    "reinforcement-learning",
    "tabular-classification",
    "tabular-regression",
    "time-series-forecasting",
    "graph-ml",
    "robotics",
    "computer-vision",
    "nlp",
    "cv",
    "multimodal",
}
|
|
|
|
class WebhookEvent(BaseModel):
    """Pydantic model naming the four top-level sections of a webhook payload.

    NOTE(review): the handlers visible in this file work on the raw JSON dict
    and do not reference this model.
    """

    # Event descriptor; the webhook handler reads "scope" and "action" from it.
    event: Dict[str, str]
    # The comment that triggered the event (the processor reads "content" and "author").
    comment: Dict[str, Any]
    # The discussion holding the comment ("title", "num", "isPullRequest" are read).
    discussion: Dict[str, Any]
    # The repository the discussion belongs to ("name" is read).
    repo: Dict[str, str]
|
|
|
|
# The FastAPI application; CORS is configured to accept every origin.
app = FastAPI(title="HF Tagging Bot")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
)
|
|
|
|
async def get_agent():
    """Return the shared Agent, creating it and loading its tools on first use.

    The instance is cached in the module-level ``agent_instance``. It is built
    with ``HF_MODEL``, ``DEFAULT_PROVIDER`` and ``HF_TOKEN`` and one stdio
    server entry that runs ``python mcp_server.py`` in the current directory
    with ``HF_TOKEN`` in its environment.

    Returns:
        The cached Agent, or ``None`` when ``HF_TOKEN`` is not set or when
        creating the agent / loading its tools raised (the error is printed
        and the cache is reset so a later call retries).
    """
    global agent_instance
    print("🤖 get_agent() called...")

    if agent_instance is not None:
        print("✅ Using existing agent instance")
        return agent_instance

    if not HF_TOKEN:
        print("❌ No HF_TOKEN available, cannot create agent")
        return agent_instance

    print("🔧 Creating new Agent instance...")
    print(f"🔑 HF_TOKEN present: {bool(HF_TOKEN)}")
    print(f"🤖 Model: {HF_MODEL}")
    print(f"🔗 Provider: {DEFAULT_PROVIDER}")

    try:
        agent_instance = Agent(
            model=HF_MODEL,
            provider=DEFAULT_PROVIDER,
            api_key=HF_TOKEN,
            servers=[
                {
                    "type": "stdio",
                    "command": "python",
                    "args": ["mcp_server.py"],
                    "cwd": ".",
                    # HF_TOKEN is known to be truthy on this path.
                    "env": {"HF_TOKEN": HF_TOKEN},
                }
            ],
        )
        print("✅ Agent instance created successfully")
        print("🔧 Loading tools...")
        await agent_instance.load_tools()
        print("✅ Tools loaded successfully")
    except Exception as e:
        # Best effort: report and leave the cache empty so callers see "no agent".
        print(f"❌ Error creating/loading agent: {str(e)}")
        agent_instance = None

    return agent_instance
|
|
|
|
def extract_tags_from_text(text: str) -> List[str]:
    """Extract candidate tags from free text.

    Three sources are combined, all matched against the lowercased text:

    * explicit lists introduced by ``tag:`` / ``tags:`` and split on commas,
    * ``#hashtag`` tokens,
    * any entry of ``RECOGNIZED_TAGS`` that occurs as a substring.

    Explicit and hashtag tags are kept even when they are not in
    ``RECOGNIZED_TAGS``. Empty or whitespace-only candidates are discarded.

    Args:
        text: Comment body or discussion title.

    Returns:
        The unique tags, sorted so the result is deterministic.
    """
    text_lower = text.lower()

    explicit_tags: List[str] = []
    # The capture class includes whitespace, so a list runs until the first
    # character that is not a letter, digit, '-', '_', ',' or whitespace.
    for listed in re.findall(r"tags?:\s*([a-zA-Z0-9\-_,\s]+)", text_lower):
        explicit_tags.extend(part.strip() for part in listed.split(","))
    explicit_tags.extend(re.findall(r"#([a-zA-Z0-9\-_]+)", text_lower))

    # Skip blank entries: a whitespace-only tag would match nearly any text.
    mentioned_tags = [
        tag for tag in RECOGNIZED_TAGS if tag.strip() and tag in text_lower
    ]

    # A trailing comma ("tags: pytorch,") yields an empty piece; drop those.
    return sorted({tag for tag in (*explicit_tags, *mentioned_tags) if tag})
|
|
|
|
def _summarize_agent_response(full_response: str, tags: List[str]) -> List[str]:
    """Derive one status line per tag from keywords in the agent transcript."""
    response_lower = full_response.lower()
    # Match "pr" only when it is not part of a longer word: a plain substring
    # test also fires on "processed", "provide" or "logprobs".
    mentions_pr = (
        re.search(r"(?<![a-z])prs?(?![a-z])", response_lower) is not None
        or "pull request" in response_lower
    )

    messages = []
    for tag in tags:
        tag_mentioned = tag.lower() in response_lower

        if "already exists" in response_lower and tag_mentioned:
            msg = f"Tag '{tag}': Already exists"
        elif mentions_pr:
            if tag_mentioned:
                msg = f"Tag '{tag}': PR created successfully"
            else:
                msg = f"Tag '{tag}': Processed (PR may have been created)"
        elif "success" in response_lower and tag_mentioned:
            msg = f"Tag '{tag}': Successfully processed"
        elif "error" in response_lower and tag_mentioned:
            msg = f"Tag '{tag}': Error during processing"
        else:
            msg = f"Tag '{tag}': Processed by agent"

        print(f"✅ Result for tag '{tag}': {msg}")
        messages.append(msg)
    return messages


def _apply_tags_directly(repo_name: str, tags: List[str]) -> List[str]:
    """Fallback path: load ./mcp_server.py and call its tag functions per tag."""
    messages = []
    try:
        import importlib.util

        spec = importlib.util.spec_from_file_location("mcp_server", "./mcp_server.py")
        mcp_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mcp_module)

        for tag in tags:
            try:
                print(f"🔧 Directly calling get_current_tags for '{tag}'")
                current_tags_result = mcp_module.get_current_tags(repo_name)
                print(f"📄 Current tags result: {current_tags_result}")

                # Both tool functions are parsed as JSON text carrying a "status" key.
                tags_data = json.loads(current_tags_result)

                if tags_data.get("status") == "success":
                    current_tags = tags_data.get("current_tags", [])
                    if tag in current_tags:
                        msg = f"Tag '{tag}': Already exists"
                        print(f"✅ {msg}")
                    else:
                        print(f"🔧 Directly calling add_new_tag for '{tag}'")
                        add_result = mcp_module.add_new_tag(repo_name, tag)
                        print(f"📝 Add tag result: {add_result}")

                        add_data = json.loads(add_result)
                        add_status = add_data.get("status")
                        if add_status == "success":
                            pr_url = add_data.get("pr_url", "")
                            msg = f"Tag '{tag}': PR created - {pr_url}"
                        elif add_status == "already_exists":
                            msg = f"Tag '{tag}': Already exists"
                        else:
                            msg = f"Tag '{tag}': {add_data.get('message', 'Processed')}"
                        print(f"✅ {msg}")
                else:
                    error_msg = tags_data.get("error", "Unknown error")
                    msg = f"Tag '{tag}': Error - {error_msg}"
                    print(f"❌ {msg}")

                messages.append(msg)

            except Exception as direct_error:
                # One failing tag must not stop the remaining tags.
                error_msg = f"Tag '{tag}': Direct call error - {str(direct_error)}"
                print(f"❌ {error_msg}")
                messages.append(error_msg)
    except Exception as fallback_error:
        error_msg = f"Fallback approach failed: {str(fallback_error)}"
        print(f"❌ {error_msg}")
        messages.append(error_msg)
    return messages


async def _tag_with_agent(agent: Any, repo_name: str, tags: List[str]) -> List[str]:
    """Ask the agent to add all tags in one conversation; fall back on failure."""
    tag_list = ", ".join(tags)
    user_prompt = (
        f"\nI need to add the following tags to the repository '{repo_name}': {tag_list}\n"
        "For each tag, please:\n"
        "1. Check if the tag already exists on the repository using get_current_tags\n"
        "2. If the tag doesn't exist, add it using add_new_tag\n"
        "3. Provide a summary of what was done for each tag\n"
        f"Please process all {len(tags)} tags: {tag_list}\n"
    )

    print("💬 Sending comprehensive prompt to agent...")
    print(f"📝 Prompt: {user_prompt}")

    transcript = []
    try:
        async for item in agent.run(user_prompt):
            item_str = str(item)
            transcript.append(item_str)

            item_lower = item_str.lower()
            if "tool_call" in item_lower or "function" in item_lower:
                print(f"🔧 Agent using tools: {item_str[:200]}...")
            elif "content" in item_str and len(item_str) < 500:
                print(f"💭 Agent response: {item_str}")

        print("📋 Agent conversation completed successfully")
        return _summarize_agent_response(" ".join(transcript), tags)
    except Exception as agent_error:
        print(f"⚠️ Agent streaming failed: {str(agent_error)}")
        print("🔄 Falling back to direct MCP tool calls...")
        return _apply_tags_directly(repo_name, tags)


async def process_webhook_comment(webhook_data: Dict[str, Any]):
    """Detect tags in a discussion comment and try to add them to the repo.

    Tags are extracted from the comment content and the discussion title. When
    any are found, the shared agent is asked to add them; if the agent
    conversation raises, the functions in ``./mcp_server.py`` are called
    directly instead. Every processed interaction is appended to
    ``tag_operations_store``.

    Args:
        webhook_data: Payload dict. Must provide ``comment.content``,
            ``comment.author`` (a mapping, ``id`` optional),
            ``discussion.title``, ``discussion.num`` and ``repo.name``.

    Returns:
        The per-tag result messages joined with ``"|"``. Any unexpected
        exception (including a missing payload key) is caught and returned as
        a ``"❌ Fatal error ..."`` string; nothing is stored in that case.
    """
    print("🏷️ Starting process_webhook_comment...")

    try:
        comment_content = webhook_data["comment"]["content"]
        discussion_title = webhook_data["discussion"]["title"]
        repo_name = webhook_data["repo"]["name"]
        discussion_num = webhook_data["discussion"]["num"]
        comment_author = webhook_data["comment"]["author"].get("id", "unknown")

        print(f"📝 Comment content: {comment_content}")
        print(f"📰 Discussion title: {discussion_title}")
        print(f"📦 Repository: {repo_name}")

        comment_tags = extract_tags_from_text(comment_content)
        title_tags = extract_tags_from_text(discussion_title)
        all_tags = list(set(comment_tags + title_tags))

        print(f"🔍 Comment tags found: {comment_tags}")
        print(f"🔍 Title tags found: {title_tags}")
        print(f"🏷️ All unique tags: {all_tags}")

        result_messages: List[str] = []

        if not all_tags:
            msg = "No recognizable tags found in the discussion."
            print(f"❌ {msg}")
            result_messages.append(msg)
        else:
            print("🤖 Getting agent instance...")
            agent = await get_agent()
            if not agent:
                msg = "Error: Agent not configured (missing HF_TOKEN please check)"
                print(f"❌ {msg}")
                result_messages.append(msg)
            else:
                print("✅ Agent instance obtained successfully")
                try:
                    result_messages.extend(
                        await _tag_with_agent(agent, repo_name, all_tags)
                    )
                except Exception as e:
                    error_msg = f"Error during agent processing {str(e)}"
                    print(f"❌ {error_msg}")
                    result_messages.append(error_msg)

        # Hub discussion pages live under "/discussions/<num>" (plural).
        base_url = "https://huggingface.co"
        discussion_url = f"{base_url}/{repo_name}/discussions/{discussion_num}"

        interaction = {
            "timestamp": datetime.now().isoformat(),
            "repo": repo_name,
            "discussion_title": discussion_title,
            "discussion_num": discussion_num,
            "discussion_url": discussion_url,
            "original_comment": comment_content,
            "comment_author": comment_author,
            "detected_tags": all_tags,
            "results": result_messages,
        }
        tag_operations_store.append(interaction)

        final_result = "|".join(result_messages)
        print(f"💾 Stored interaction and returning result: {final_result}")
        return final_result

    except Exception as e:
        error_msg = f"❌ Fatal error in process_webhook_comment: {str(e)}"
        print(error_msg)
        return error_msg
|
|
|
|
@app.get("/")
async def root():
    """Describe the service: name, status, and the paths of its main endpoints."""
    endpoint_paths = {
        "webhook": "/webhook",
        "health": "/health",
        "operations": "/operations",
    }
    service_info = {
        "name": "HF Tagging Bot",
        "status": "running",
        "description": "Webhook listener for automatic model tagging",
        "endpoints": endpoint_paths,
    }
    return service_info
|
|
|
|
@app.get("/health")
async def health_check():
    """Report whether the secret, the token and the agent are available.

    Calls get_agent(), so a health probe may itself trigger agent creation.
    """
    agent = await get_agent()
    checked_at = datetime.now().isoformat()

    secret_state = "configured" if WEBHOOK_SECRET else "missing"
    token_state = "configured" if HF_TOKEN else "missing"
    agent_state = "ready" if agent else "not ready"

    return {
        "status": "healthy",
        "timestamp": checked_at,
        "components": {
            "webhook_secret": secret_state,
            "hf_token": token_state,
            "mcp_agent": agent_state,
        },
    }
|
|
@app.get("/operations")
async def get_operations():
    """Return the total number of stored operations and the 50 most recent."""
    history = tag_operations_store
    # Slicing an empty list already yields an empty list.
    latest = history[-50:]
    return {
        "total_operations": len(history),
        "recent_operations": latest,
    }
|
|
|
|
|
|
@app.post("/webhook")
async def webhook_handler(request: Request, background_tasks: BackgroundTasks):
    """
    Handle incoming webhooks from Hugging Face Hub
    Following the pattern from: https://raw.githubusercontent.com/huggingface/hub-docs/refs/heads/main/docs/hub/webhooks-guide-discussion-bot.md

    Only events with scope "discussion" and action "create" that are not pull
    requests are processed; processing runs as a background task so the
    request returns immediately.

    Returns:
        ``{"error": ...}`` for a wrong secret or missing payload sections,
        ``{"status": "processing"}`` when a task was scheduled, and
        ``{"status": "ignored"}`` for every other event.
    """
    print("🔔 Webhook received!")

    # NOTE(review): when WEBHOOK_SECRET is unset, a request without the header
    # compares None to None and is accepted.
    webhook_secret = request.headers.get("X-webhook-Secret")
    if webhook_secret != WEBHOOK_SECRET:
        print("❌ Invalid webhook secret")
        return {"error": "incorrect secret"}

    payload = await request.json()
    print(f"📥 Received webhook payload: {json.dumps(payload, indent=2)}")

    event = payload.get("event", {})
    # The key is "scope"; reading a misspelled key always yields None and
    # causes every event to be ignored.
    scope = event.get("scope")
    action = event.get("action")
    print(f"🔍 Event details - scope: {scope}, action: {action}")

    # Tolerate payloads without a "discussion" section: they are either
    # ignored below or rejected by the required-fields check.
    discussion = payload.get("discussion") or {}
    not_pr = not discussion.get("isPullRequest", False)
    scope_check = scope == "discussion" and not_pr
    action_check = action == "create"

    print(f"✅ not_pr: {not_pr}")
    print(f"✅ scope_check: {scope_check}")
    print(f"✅ action_check: {action_check}")

    if scope_check and action_check:
        required_fields = ["comment", "discussion", "repo"]
        missing_fields = [field for field in required_fields if field not in payload]

        if missing_fields:
            error_msg = f"Missing required fields: {missing_fields}"
            print(f"❌ {error_msg}")
            return {"error": error_msg}

        print(f"🚀 Processing webhook for repo: {payload['repo']['name']}")
        background_tasks.add_task(process_webhook_comment, payload)
        return {"status": "processing"}

    print(f"⏭️ Ignoring webhook - scope: {scope}, action: {action}")
    return {"status": "ignored"}
|
|
@app.post("/simulate_webhook")
async def simulate_webhook(repo_name: str, discussion_title: str, comment_content: str) -> str:
    """Simulate webhook for testing purposes.

    Builds a mock payload from the three inputs and runs it through
    process_webhook_comment() directly (no secret check, no background task).

    Args:
        repo_name: Repository id to tag.
        discussion_title: Title used for the mock discussion.
        comment_content: Body used for the mock comment.

    Returns:
        A prompt to fill in all fields when any input is empty, otherwise the
        processing result prefixed with "✅ Processed! Results: ".
    """
    if not all([repo_name, discussion_title, comment_content]):
        return "please fill in all fields"

    mock_payload = {
        "event": {"action": "create", "scope": "discussion.comment"},
        "comment": {
            "content": comment_content,
            "author": {"id": "test-user"},
            "id": "mock-comment-id",
            "hidden": False,
        },
        "discussion": {
            "title": discussion_title,
            # Next sequence number relative to what has been stored so far.
            "num": len(tag_operations_store) + 1,
            "id": "mock-comment-id",
            "status": "open",
            "isPullRequest": False,
        },
        "repo": {"name": repo_name, "type": "model", "private": False},
    }

    response = await process_webhook_comment(mock_payload)
    return f"✅ Processed! Results: {response}"
|
|
def create_gradio_app():
    """Create Gradio interface.

    The dashboard has a short explanation, a form (repository, discussion
    title, comment) whose button calls simulate_webhook(), a result box, and
    the list of recognized tags.

    Returns:
        The assembled ``gr.Blocks`` instance.
    """
    # Built line by line so the Markdown carries no source-code indentation.
    how_it_works = "\n".join(
        [
            "",
            "## How it works:",
            "- Monitors HuggingFace Hub discussions",
            '- Detects tag mentions in comments (e.g., "tag: pytorch",',
            '"#transformers")',
            "- Automatically adds recognized tags to the model repository",
            "- Supports common ML tags like: pytorch, tensorflow,",
            "text-generation, etc.",
            "",
        ]
    )

    with gr.Blocks(title="HF Tagging Bot", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🏷️ HF Tagging Bot Dashboard")
        gr.Markdown("*Automatically adds tags to models when mentioned in discussions*")
        gr.Markdown(how_it_works)

        with gr.Column():
            sim_repo = gr.Textbox(
                label="Repository",
                value="burtenshaw/play-mcp-repo-bot",
                placeholder="username/model-name",
            )
            sim_title = gr.Textbox(
                label="Discussion Title",
                value="Add pytorch tag",
                placeholder="Discussion title",
            )
            sim_comment = gr.Textbox(
                label="comment",
                lines=3,
                value="This model should have tags: pytorch, text-generation",
                placeholder="Comment mentioning tags ...",
            )
            sim_btn = gr.Button("🏷️ Test Tag Detection")

        with gr.Column():
            sim_result = gr.Textbox(label="Result", lines=8)

        sim_btn.click(
            fn=simulate_webhook,
            inputs=[sim_repo, sim_title, sim_comment],
            outputs=sim_result,
        )

        gr.Markdown(f"""## Recognized Tags: {",".join(sorted(RECOGNIZED_TAGS))}""")

    return demo
# Build the dashboard at import time and attach it to the FastAPI app
# under the /gradio path.
gradio_app = create_gradio_app()
app = gr.mount_gradio_app(
    app,
    gradio_app,
    path="/gradio",
)
|
|
@app.get("/")
async def root_direct():
    """Redirect the site root to the Gradio dashboard.

    NOTE(review): an earlier handler in this file is also registered for
    GET "/"; confirm which of the two is meant to serve the root path.
    """
    dashboard_path = "/gradio"
    return RedirectResponse(url=dashboard_path)
|
|
# Script entry point: start uvicorn with auto-reload on port 7860.
# NOTE(review): the "app:app" import string assumes this module is app.py — confirm.
if __name__ == "__main__":
    print("🚀 Starting HF Tagging Bot...")
    print("📊 Dashboard: http://localhost:7860/gradio")
    print("🔗 Webhook: http://localhost:7860/webhook")
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)