| from __future__ import annotations |
|
|
| from argparse import ArgumentParser |
|
|
| import datasets |
| import gradio as gr |
| import numpy as np |
| import openai |
|
|
| from dataset_creation.generate_txt_dataset import generate |
|
|
|
|
def main(openai_model: str) -> None:
    """Launch a Gradio demo that produces an edit instruction and edited caption.

    Loads the ``TEXT`` column of the
    ``ChristophSchuhmann/improved_aesthetics_6.5plus`` training split in a
    random order, builds a UI with one editable and two read-only textboxes,
    and launches it with ``share=True``.

    Args:
        openai_model: Model identifier forwarded unchanged as the first
            argument of ``generate``.
    """
    dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
    # Shuffle once up front so "Random Input" can simply walk the captions in order.
    captions = dataset[np.random.permutation(len(dataset))]["TEXT"]
    index = 0

    def click_random():
        """Return the next caption in the shuffled order, wrapping around at the end."""
        nonlocal index
        output = captions[index]
        index = (index + 1) % len(captions)
        return output

    def click_generate(caption: str):
        """Run ``generate`` on the caption and return values for the two output boxes.

        Raises:
            gr.Error: If the caption is empty or contains only whitespace.
        """
        # Reject blank and whitespace-only captions before calling the model:
        # there is nothing meaningful to edit, so the request would be wasted.
        if not caption or not caption.strip():
            raise gr.Error("Input caption is missing!")
        edit_output = generate(openai_model, caption)
        if edit_output is None:
            # ``generate`` signals failure with None; show a placeholder in both boxes.
            return "Failed :(", "Failed :("
        # Otherwise the result is handed straight to the (instruction, edited caption) outputs.
        return edit_output

    with gr.Blocks(css="footer {visibility: hidden}") as demo:
        txt_input = gr.Textbox(lines=3, label="Input Caption", interactive=True, placeholder="Type image caption here...")
        txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False)
        txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False)

        with gr.Row():
            clear_btn = gr.Button("Clear")
            random_btn = gr.Button("Random Input")
            generate_btn = gr.Button("Generate Instruction + Edited Caption")

        # "Clear" blanks all three textboxes at once.
        clear_btn.click(fn=lambda: ("", "", ""), inputs=[], outputs=[txt_input, txt_edit, txt_output])
        random_btn.click(fn=click_random, inputs=[], outputs=[txt_input])
        generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=[txt_edit, txt_output])

    demo.launch(share=True)
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: both flags are mandatory strings.
    cli = ArgumentParser()
    for flag in ("--openai-api-key", "--openai-model"):
        cli.add_argument(flag, type=str, required=True)
    options = cli.parse_args()

    # Set the API key on the openai module before the demo starts.
    openai.api_key = options.openai_api_key
    main(options.openai_model)
|
|