| import base64 |
| import copy |
| from datetime import datetime |
| import json |
| import fire |
| import os |
| import pathlib |
|
|
| from poster.figures import extract_figures |
| from poster.poster import ( |
| generate_html_v2, |
| generate_poster_v3, |
| replace_figures_in_poster, |
| replace_figures_size_in_poster, |
| ) |
|
|
|
|
# Substrings of model-provider error messages for which a retry is pointless:
# the same request would be rejected again, so the error is re-raised.
_FATAL_ERROR_MARKERS = (
    "content management policy",
    "message larger than max",
    "exceeds the maximum length",
    "maximum context length",
    "Input is too long",
    "image exceeds 5 MB",
    "too many total text bytes",
    "Range of input length",
    "Invalid text",
)

# Minimum score an extracted figure/table must reach to be kept.
_FIGURE_SCORE_THRESHOLD = 0.75

# Total number of poster-generation attempts before giving up.
_MAX_ATTEMPTS = 3


def _figures_cache_path(pdf: str) -> str:
    """Return the path of the figures cache file that sits next to *pdf*."""
    # Strip only a trailing ".pdf"; a blanket str.replace would also mangle
    # directory names that happen to contain ".pdf".
    stem = pdf[: -len(".pdf")] if pdf.endswith(".pdf") else pdf
    return f"{stem}_figures.json"


def _is_fatal_error(error: Exception) -> bool:
    """Return True when *error*'s message contains a non-retryable marker."""
    message = str(error)
    return any(marker in message for marker in _FATAL_ERROR_MARKERS)


def _load_or_extract_figures(url: str, pdf: str) -> list:
    """Load figures from the JSON cache, or extract them and fill the cache."""
    cache_path = _figures_cache_path(pdf)

    # Only the figures cache is checked: it is the only cache file this
    # module writes, so requiring any other file would make it unreachable.
    if os.path.exists(cache_path):
        print(f"使用缓存的图片: {cache_path}")
        with open(cache_path, "r", encoding="utf-8") as f:
            return json.load(f)

    # Each extraction result is consumed as a sequence of (image, score) pairs.
    figures_img = extract_figures(url, pdf, task="figure")
    figures_table = extract_figures(url, pdf, task="table")
    figures = [
        image
        for image, score in figures_img + figures_table
        if score >= _FIGURE_SCORE_THRESHOLD
    ]

    # ensure_ascii=False may emit non-ASCII text, so pin the file encoding
    # instead of relying on the platform default.
    with open(cache_path, "w", encoding="utf-8") as f:
        json.dump(figures, f, ensure_ascii=False)
    return figures


def generate_paper_poster(
    url: str,
    pdf: str,
    vendor: str = "openai",
    model: str = "gpt-4o-mini",
    text_prompt: str = "",
    figures_prompt: str = "",
    output: str = "poster.json",
):
    """Generate a paper poster and its HTML rendering.

    Figures and tables are extracted from the PDF (or loaded from the
    ``<pdf stem>_figures.json`` cache next to it), then the poster and its
    HTML are generated, retrying on non-fatal errors.

    Args:
        url: URL of the PDF file
        pdf: Local path of the PDF file
        vendor: Name of the model vendor, default is openai
        model: Name of the model to use, default is gpt-4o-mini
        text_prompt: Text prompt template
        figures_prompt: Figures prompt template
        output: Output file path, default is poster.json. This function does
            not currently read or write it; it is kept for interface
            compatibility.

    Returns:
        A ``(poster, html)`` tuple on success, or ``(None, None)`` when every
        attempt failed with a non-fatal error.

    Raises:
        Exception: Re-raises any generation error whose message matches one of
            the non-retryable markers (content policy, oversized input, ...).
            Errors raised while extracting figures are not caught.
    """
    print("开始提取图片...")
    figures = _load_or_extract_figures(url, pdf)

    print("开始生成海报...")
    for _ in range(_MAX_ATTEMPTS):
        try:
            # `figures` is intentionally passed for both trailing arguments,
            # exactly as before. NOTE(review): confirm this matches the
            # signature of generate_poster_v3.
            result = generate_poster_v3(
                vendor, model, text_prompt, figures_prompt, pdf, figures, figures
            )
            poster = result["image_based_poster"]

            # Keep an untouched copy for the size-annotated variant.
            # NOTE(review): presumably replace_figures_in_poster mutates its
            # argument - confirm.
            backup_poster = copy.deepcopy(poster)
            poster = replace_figures_in_poster(poster, figures)
            poster_size = replace_figures_size_in_poster(backup_poster, figures)

            print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Now generating HTML...")
            result = generate_html_v2(vendor, model, poster_size, figures)
            html = result["html_with_figures"]

            print("海报生成成功!")
            return poster, html
        except Exception as e:
            # Broad catch is deliberate: this is the retry boundary. Errors
            # that can never succeed on retry are propagated to the caller.
            if _is_fatal_error(e):
                raise
            print(f"处理文件 {pdf} 时出错: {e}")

    return None, None
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: hand the generator to python-fire so it can be
    # invoked from the command line.
    fire.Fire(generate_paper_poster)
|
|