| from .base_agent import BaseAgent |
| from prompt.template import CREATE_CHART_PROMPT |
|
|
class Chart(BaseAgent):
    """Agent that asks the LLM to produce charts for a paper.

    Each chart is requested with a prompt built from ``CREATE_CHART_PROMPT``.
    """

    def __init__(self, llm):
        """Forward ``llm`` to the ``BaseAgent`` initializer.

        NOTE(review): the methods below read ``self.llm``; presumably
        ``BaseAgent.__init__`` stores it there -- confirm against the base
        class.
        """
        super().__init__(llm)

    def create_single_chart(self, paper_content: str, existing_charts: str, user_prompt: str = ''):
        """Request one chart from the LLM.

        Args:
            paper_content: Text substituted into the template's
                ``paper_content`` field.
            existing_charts: Text substituted into the template's
                ``existing_charts`` field.
            user_prompt: Text substituted into the template's
                ``user_prompt`` field; defaults to an empty string.

        Returns:
            Whatever ``self.llm.generate`` returns for the filled-in prompt.
        """
        filled_prompt = CREATE_CHART_PROMPT.format(
            paper_content=paper_content,
            existing_charts=existing_charts,
            user_prompt=user_prompt,
        )
        return self.llm.generate(filled_prompt)

    def create_charts(self, paper_content: str, chart_num: int, user_prompt: str = ''):
        """Request ``chart_num`` charts, one LLM call per chart.

        Every call after the first receives all previously generated charts,
        joined with ``'\\n---\\n'``, as its ``existing_charts`` argument
        (presumably so the model can avoid repeating itself -- verify against
        the prompt template).

        Args:
            paper_content: Passed unchanged to every ``create_single_chart``
                call.
            chart_num: Number of charts to request.
            user_prompt: Passed unchanged to every ``create_single_chart``
                call; defaults to an empty string.

        Returns:
            A list with the result of each request, in generation order.
        """
        separator = '\n---\n'
        generated = []
        previous_charts = ''
        for _ in range(chart_num):
            generated.append(
                self.create_single_chart(paper_content, previous_charts, user_prompt)
            )
            # Refresh the context right after each result so the next
            # request sees everything produced so far.
            previous_charts = separator.join(generated)
        return generated
|
|