Spaces:
Runtime error
Runtime error
Update demo_model.py
Browse files- demo_model.py +1 -1
demo_model.py
CHANGED
|
@@ -79,7 +79,7 @@ class LGGMText2Graph_Demo(pl.LightningModule):
|
|
| 79 |
with torch.no_grad():
|
| 80 |
prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]
|
| 81 |
|
| 82 |
-
samples = self.sample_batch(
|
| 83 |
|
| 84 |
nx_graphs = []
|
| 85 |
for graph in samples:
|
|
|
|
| 79 |
with torch.no_grad():
|
| 80 |
prompt_emb = self.text_encoder(**encoded_input).hidden_states[-1][:, 0]
|
| 81 |
|
| 82 |
+
samples = self.sample_batch(3, cond_emb = prompt_emb.to(self.device), num_nodes = num_nodes)
|
| 83 |
|
| 84 |
nx_graphs = []
|
| 85 |
for graph in samples:
|