| import base64 |
| import os |
| import tempfile |
| from io import BytesIO |
|
|
| import gradio as gr |
| import pandas as pd |
| import requests |
|
|
|
|
# Base URL of the backend API; override with MULTIMODAL_API_BASE_URL.
API_BASE_URL = os.environ.get("MULTIMODAL_API_BASE_URL", "http://127.0.0.1:7861")

# Each endpoint defaults to a path under API_BASE_URL (trailing slashes on the
# base are removed first) and can be overridden individually via its own
# environment variable.
EXTRACTION_API_URL = os.environ.get(
    "MULTIMODAL_API_URL",
    API_BASE_URL.rstrip("/") + "/information_extraction/",
)
MAPPING_API_URL = os.environ.get(
    "MULTIMODAL_MAPPING_API_URL",
    API_BASE_URL.rstrip("/") + "/mapping/",
)
|
|
|
|
def process_zip(zip_file, employee_code, debug=False):
    """Upload a ZIP archive to the extraction API and return its Excel result.

    Args:
        zip_file: Filesystem path of the ZIP archive to upload, or ``None``
            when nothing was selected.
        employee_code: Employee code sent to the API; surrounding whitespace
            is stripped before use.
        debug: Sent to the API as the form field ``debug``, encoded as the
            lowercase string form of the value (``"true"`` / ``"false"``).

    Returns:
        A ``(status, dataframe, output_path)`` tuple: a human-readable status
        line, the decoded Excel content as a ``pandas.DataFrame``, and the
        path of the ``.xlsx`` file written for download.

    Raises:
        gr.Error: If an input is missing, the upload cannot be read, the API
            cannot be reached, or the API response is malformed, reports a
            failure, or carries no readable Excel payload.
    """
    if zip_file is None:
        raise gr.Error("Please upload a ZIP file.")
    if not employee_code or not employee_code.strip():
        raise gr.Error("Please enter an employee code.")
    employee_code = employee_code.strip()

    try:
        with open(zip_file, "rb") as file_obj:
            response = requests.post(
                EXTRACTION_API_URL,
                files={
                    "file": (
                        os.path.basename(zip_file),
                        file_obj,
                        "application/zip",
                    )
                },
                data={
                    "employee_code": employee_code,
                    "debug": str(debug).lower(),
                },
                timeout=300,
            )
    # RequestException derives from OSError, so it must be handled first.
    except requests.RequestException as exc:
        raise gr.Error(f"Could not reach the extraction API: {exc}") from exc
    except OSError as exc:
        raise gr.Error("Could not read the uploaded ZIP file.") from exc

    try:
        payload = response.json()
    except ValueError as exc:
        raise gr.Error("The API returned an invalid response.") from exc
    # Valid JSON that is not an object (e.g. a list) has no ``.get``.
    if not isinstance(payload, dict):
        raise gr.Error("The API returned an invalid response.")

    if response.status_code != 200:
        error_message = payload.get("detail") or payload.get("message") or "Request failed."
        # ``detail`` may be a structured value (list/dict); show it as text.
        raise gr.Error(str(error_message))

    excel_base64 = payload.get("excel_data_base64")
    if not excel_base64:
        raise gr.Error("The API response did not include an Excel file.")

    try:
        # binascii.Error (bad base64) is a ValueError subclass.
        excel_bytes = base64.b64decode(excel_base64)
    except (ValueError, TypeError) as exc:
        raise gr.Error("The API returned an unreadable Excel file.") from exc
    try:
        dataframe = pd.read_excel(BytesIO(excel_bytes))
    except Exception as exc:
        # The Excel engines raise assorted exception types for corrupt input;
        # convert all of them into one user-facing error, keeping the cause.
        raise gr.Error("The API returned an unreadable Excel file.") from exc

    # The employee code is user input: keep only filename-safe characters so it
    # cannot introduce path separators, and write into a fresh private
    # directory so concurrent requests never overwrite each other's file.
    safe_code = "".join(
        ch if ch.isalnum() or ch in "-_" else "_" for ch in employee_code
    )
    output_dir = tempfile.mkdtemp(prefix="extraction_result_")
    output_path = os.path.join(output_dir, f"extraction_result_{safe_code}.xlsx")
    with open(output_path, "wb") as excel_file:
        excel_file.write(excel_bytes)

    status = (
        f"{payload.get('message', 'Processing completed.')} "
        f"Time: {payload.get('duration', 'N/A')}s."
    )

    return status, dataframe, output_path
|
|
|
|
def process_mapping(product_list, dense_weight=0.7, sparse_weight=0.3, normalize=True):
    """Send product names to the mapping API and tabulate the returned matches.

    Args:
        product_list: Text containing the product names; surrounding
            whitespace is stripped before sending.
        dense_weight: Sent unchanged as the form field ``dense_weight``.
        sparse_weight: Sent unchanged as the form field ``sparse_weight``.
        normalize: Sent as the form field ``normalize``, encoded as the
            lowercase string form of the value (``"true"`` / ``"false"``).

    Returns:
        A ``(status, dataframe)`` tuple: a human-readable status line and a
        ``pandas.DataFrame`` restricted to the columns
        ``original_product_name`` and ``top_1`` .. ``top_5``.

    Raises:
        gr.Error: If no product name was entered, the API cannot be reached,
            or the API response is malformed, reports a failure, or has no
            ``results`` list.
    """
    if not product_list or not product_list.strip():
        raise gr.Error("Please enter at least one product name.")

    try:
        response = requests.post(
            MAPPING_API_URL,
            data={
                "product_list": product_list.strip(),
                "dense_weight": dense_weight,
                "sparse_weight": sparse_weight,
                "normalize": str(normalize).lower(),
            },
            timeout=300,
        )
    except requests.RequestException as exc:
        raise gr.Error(f"Could not reach the mapping API: {exc}") from exc

    try:
        payload = response.json()
    except ValueError as exc:
        raise gr.Error("The API returned an invalid response.") from exc
    # Valid JSON that is not an object (e.g. a list) has no ``.get``.
    if not isinstance(payload, dict):
        raise gr.Error("The API returned an invalid response.")

    if response.status_code != 200:
        error_message = payload.get("detail") or payload.get("message") or "Request failed."
        # ``detail`` may be a structured value (list/dict); show it as text.
        raise gr.Error(str(error_message))

    results = payload.get("results")
    if not isinstance(results, list):
        raise gr.Error("The API response did not include mapping results.")

    preferred_columns = [
        "original_product_name",
        "top_1",
        "top_2",
        "top_3",
        "top_4",
        "top_5",
    ]
    dataframe = pd.DataFrame(results)
    if dataframe.empty:
        # Show the expected headers even when there is nothing to display.
        dataframe = pd.DataFrame(columns=preferred_columns)
    else:
        # Fix the column order; columns missing from the results become NaN
        # and any extra columns are dropped.
        dataframe = dataframe.reindex(columns=preferred_columns)

    total_products = payload.get("total_products", len(dataframe))
    duration = payload.get("api_duration", payload.get("duration", "N/A"))
    status = f"Mapped {total_products} products. Time: {duration}s."

    return status, dataframe
|
|
|
|
# Gradio UI definition: one page with two tabs, each wiring its inputs to one
# of the request functions above (process_zip / process_mapping).
# NOTE(review): the original indentation was lost, so component nesting below
# was reconstructed from syntax — confirm in particular that the separator,
# examples, checkbox and button after the sliders belong outside the slider row.
with gr.Blocks(title="Multimodal OCR and Product Mapping") as demo:
    # NOTE(review): indentation inside this literal was reconstructed; Gradio is
    # expected to dedent Markdown text — confirm the headings still render.
    gr.Markdown(
        """
        # Multimodal OCR and Product Mapping Interface
        ## Run information extraction or product mapping from a single interface.
        """
    )

    with gr.Tabs():
        # Tab 1: upload a ZIP archive and receive an Excel result.
        with gr.Tab("Information Extraction"):
            with gr.Row():
                # type="filepath" hands process_zip a path string to open.
                zip_input = gr.File(
                    label="ZIP File",
                    file_types=[".zip"],
                    type="filepath",
                )
                employee_code_input = gr.Textbox(
                    label="Employee Code",
                    placeholder="e.g. admin",
                )

            debug_input = gr.Checkbox(label="Debug mode", value=False)
            submit_button = gr.Button("Process", variant="primary")

            # Read-only outputs, in the order process_zip returns them.
            status_output = gr.Textbox(label="Status", interactive=False)
            excel_preview = gr.Dataframe(label="Excel Preview", interactive=False)
            excel_download = gr.File(label="Download Excel", interactive=False)

            submit_button.click(
                fn=process_zip,
                inputs=[zip_input, employee_code_input, debug_input],
                outputs=[status_output, excel_preview, excel_download],
            )

        # Tab 2: map free-text product names through the mapping API.
        with gr.Tab("Mapping"):
            product_input = gr.Textbox(
                label="Product List",
                placeholder="One product per line",
                lines=10,
            )

            # Slider defaults match process_mapping's parameter defaults.
            with gr.Row():
                dense_weight_input = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.7,
                    step=0.1,
                    label="Dense Weight",
                )
                sparse_weight_input = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.3,
                    step=0.1,
                    label="Sparse Weight",
                )
            gr.Markdown("---")
            # Vietnamese heading: "Examples".
            gr.Markdown("#### 💡 Ví dụ")
            # Sample product lists that fill product_input; the label is
            # Vietnamese for "Click to try the examples".
            gr.Examples(
                examples=[
                    ["Đèn LED ốp trần 18W\nBóng LED E27 9W\nĐèn downlight âm trần"],
                    ["Bóng Trụ 30W\nPhích nước 2L\nVợt bắt muỗi"],
                    ["Đèn năng lượng mặt trời\nĐèn sạc tích điện\nBóng compact 20W"],
                ],
                inputs=[product_input],
                label="Click để thử các ví dụ"
            )
            normalize_input = gr.Checkbox(label="Normalize Query", value=True)
            mapping_button = gr.Button("Run Mapping", variant="primary")

            # Read-only outputs, in the order process_mapping returns them.
            mapping_status = gr.Textbox(label="Status", interactive=False)
            mapping_preview = gr.Dataframe(label="Mapping Preview", interactive=False)

            # Input order must match process_mapping's positional parameters.
            mapping_button.click(
                fn=process_mapping,
                inputs=[
                    product_input,
                    dense_weight_input,
                    sparse_weight_input,
                    normalize_input,
                ],
                outputs=[mapping_status, mapping_preview],
            )
|
|
|
|
# Start the Gradio app only when this file is run as a script, so importing
# the module does not launch it.
if __name__ == "__main__":
    demo.launch()
|
|