Spaces:
Runtime error
Runtime error
| from omegaconf import OmegaConf | |
| import gradio as gr | |
| from dataset import init_dataset, compute_input_output_dims | |
| from extra_features import ExtraFeatures | |
| from demo_model import LGGMText2Graph_Demo | |
| from analysis.spectre_utils import CrossDomainSamplingMetrics | |
| import networkx as nx | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torch | |
# Load the experiment configuration. Paths here are relative, so the script is expected
# to be launched from the directory containing config.yaml and cc-deg.ckpt.
cfg = OmegaConf.load('./config.yaml')
hydra_path = '.'

# Build the dataset objects described by the config; max_n_nodes sizes the extra features
# and data_loaders['train'] is used to derive the model's input/output dimensions.
(
    data_loaders,
    num_classes,
    max_n_nodes,
    nodes_dist,
    edge_types,
    node_types,
    n_nodes,
    cond_dims,
    cond_emb,
) = init_dataset(
    cfg.dataset.name,
    cfg.train.batch_size,
    hydra_path,
    cfg.general.condition,
    cfg.model.transition,
)
extra_features = ExtraFeatures(cfg.model.extra_features, max_n_nodes)
# NOTE(review): input_dims, output_dims and sampling_metrics are not referenced by the code
# visible in this file; confirm whether they (or their constructors' side effects) are
# needed before removing them.
input_dims, output_dims = compute_input_output_dims(data_loaders['train'], extra_features)
sampling_metrics = CrossDomainSamplingMetrics(data_loaders)

# Load the trained checkpoint with all tensors mapped onto the CPU (the demo host has no
# GPU requirement), then initialise the pretrained prompt encoder used by predict().
model = LGGMText2Graph_Demo.load_from_checkpoint('cc-deg.ckpt', map_location=torch.device("cpu"))
model.init_prompt_encoder_pretrained()
def calculate_average_degree(graph):
    """Return the mean node degree of ``graph``, computed as 2 * |E| / |V|.

    A graph with no nodes yields ``0`` instead of dividing by zero.
    """
    node_count = graph.number_of_nodes()
    edge_count = graph.number_of_edges()
    if node_count <= 0:
        return 0
    # Each undirected edge contributes to the degree of two endpoints.
    return 2 * edge_count / node_count
# Number of graph slots (image + clustering coefficient + degree) shown in the UI.
_NUM_DISPLAY_SLOTS = 3


def _render_graph(g):
    """Draw ``g`` with NetworkX and return the figure as an RGB uint8 array (height, width, 3)."""
    fig, ax = plt.subplots()
    try:
        nx.draw(g, ax=ax)
        fig.canvas.draw()
        # ``tostring_rgb()`` was deprecated in Matplotlib 3.8 and later removed. ``buffer_rgba()``
        # is the supported API and is already shaped (height, width, 4) in device pixels, so
        # no manual reshape from get_width_height() (which can disagree on HiDPI canvases) is
        # needed. Copy before the figure is closed so the array does not alias the canvas.
        return np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Release the figure even if drawing fails, so figures do not accumulate in pyplot.
        plt.close(fig)


def _average_clustering(g):
    """Return the average clustering coefficient of ``g``, or 0.0 if ``g`` has no nodes.

    ``nx.average_clustering`` divides by the node count and would raise ZeroDivisionError
    on an empty graph.
    """
    return nx.average_clustering(g) if g.number_of_nodes() > 0 else 0.0


def _fill_slots(values):
    """Return the first ``_NUM_DISPLAY_SLOTS`` items of ``values``, padded with None."""
    shown = list(values[:_NUM_DISPLAY_SLOTS])
    return shown + [None] * (_NUM_DISPLAY_SLOTS - len(shown))


def predict(text, num_nodes = None):
    """Generate graphs for a text prompt and summarise them for the UI.

    Args:
        text: The natural-language prompt passed to ``model.generate_pretrained``.
        num_nodes: Requested number of nodes per generated graph (cast to ``int``).

    Returns:
        An 11-tuple matching the UI outputs: 3 rendered images, 3 average clustering
        coefficients, 3 average degrees (one per displayed graph, ``None`` for missing
        slots), then the mean clustering coefficient and mean degree over all generated
        graphs (``None`` when no graphs were generated).
    """
    # Cast as before. A None is forwarded unchanged instead of crashing in int().
    # NOTE(review): assumes generate_pretrained accepts None for num_nodes; confirm.
    n = None if num_nodes is None else int(num_nodes)
    graphs = list(model.generate_pretrained(text, n))

    ccs = [_average_clustering(g) for g in graphs]
    degs = [calculate_average_degree(g) for g in graphs]
    # Only graphs that fit in the UI are rendered; rendering is the expensive step.
    images = [_render_graph(g) for g in graphs[:_NUM_DISPLAY_SLOTS]]

    # The averages cover every generated graph, not only the displayed ones.
    avg_cc = np.mean(ccs) if ccs else None
    avg_deg = np.mean(degs) if degs else None

    return (*_fill_slots(images), *_fill_slots(ccs), *_fill_slots(degs), avg_cc, avg_deg)
def clear(input_text):
    """Produce an empty value for each of the 11 output components.

    ``input_text`` is accepted only because the UI wires the textbox as an input; it is ignored.
    """
    empty_outputs = (None,) * 11
    return empty_outputs
# Build the Gradio interface: prompt + node-count inputs, 3 graph images with per-graph
# statistics, overall averages, and Submit/Clear buttons.
with gr.Blocks() as demo:
    gr.Markdown("## Text2Graph Generation Demo")
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Input your text prompt here", placeholder="Type here...")
        with gr.Column():
            input_num = gr.Slider(5, 100, value=25, step = 1, label="Count", info="Number of nodes in the graph to be generated")
        with gr.Column():
            gr.Markdown("### Suggested Prompts")
            gr.Markdown("1. Create a complex network with high clustering coefficient.\n2. Create a graph with extremely low number of triangles.\n 3. Please give me a Power Network with extremely low number of triangles but with medium level of average degree.")
    with gr.Row():
        output_images = [gr.Image(label = f"Generated Network #{i}") for i in range(3)]
    with gr.Row():
        output_texts_cc = [gr.Textbox(label=f"CC #{i}") for i in range(3)]
    with gr.Row():
        output_texts_deg = [gr.Textbox(label=f"DEG #{i}") for i in range(3)]
    with gr.Row():
        avg_cc_text = gr.Textbox(label="Average Clustering Coefficient")
        avg_deg_text = gr.Textbox(label="Average Degree")
    with gr.Row():
        submit_button = gr.Button("Submit")
        clear_button = gr.Button("Clear")

    # Output components in the order predict() and clear() return their 11 values.
    all_outputs = output_images + output_texts_cc + output_texts_deg + [avg_cc_text, avg_deg_text]

    # Generation runs on the Submit button and on pressing Enter in the prompt box.
    submit_button.click(fn=predict, inputs=[input_text, input_num], outputs=all_outputs)
    input_text.submit(fn=predict, inputs=[input_text, input_num], outputs=all_outputs)
    # Clear empties every output component; the prompt textbox itself is left unchanged.
    clear_button.click(fn=clear, inputs=[input_text], outputs=all_outputs)
demo.launch()