Spaces:
Runtime error
Runtime error
Update dataset.py
Browse files- dataset.py +1 -57
dataset.py
CHANGED
|
@@ -34,39 +34,7 @@ def load_dataset_cc(dataname, batch_size, hydra_path, condition):
|
|
| 34 |
model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 35 |
cond_embs = model.encode(condition)
|
| 36 |
|
| 37 |
-
|
| 38 |
-
if not os.path.exists(f'{hydra_path}/graphs/{domain}/train.pt'):
|
| 39 |
-
|
| 40 |
-
data = torch.load(f'{hydra_path}/graphs/{domain}/{domain}.pt')
|
| 41 |
-
|
| 42 |
-
#fix seed
|
| 43 |
-
torch.manual_seed(0)
|
| 44 |
-
|
| 45 |
-
#random permute and split
|
| 46 |
-
n = len(data)
|
| 47 |
-
indices = torch.randperm(n)
|
| 48 |
-
|
| 49 |
-
if domain == 'eco':
|
| 50 |
-
train_indices = indices[:4].repeat(50)
|
| 51 |
-
val_indices = indices[4:5].repeat(50)
|
| 52 |
-
test_indices = indices[5:]
|
| 53 |
-
else:
|
| 54 |
-
train_indices = indices[:int(0.7 * n)]
|
| 55 |
-
val_indices = indices[int(0.7 * n):int(0.8 * n)]
|
| 56 |
-
test_indices = indices[int(0.8 * n):]
|
| 57 |
-
|
| 58 |
-
train_data = [data[_] for _ in train_indices]
|
| 59 |
-
val_data = [data[_] for _ in val_indices]
|
| 60 |
-
test_data = [data[_] for _ in test_indices]
|
| 61 |
-
|
| 62 |
-
torch.save(train_indices, f'{hydra_path}/graphs/{domain}/train_indices.pt')
|
| 63 |
-
torch.save(val_indices, f'{hydra_path}/graphs/{domain}/val_indices.pt')
|
| 64 |
-
torch.save(test_indices, f'{hydra_path}/graphs/{domain}/test_indices.pt')
|
| 65 |
-
|
| 66 |
-
torch.save(train_data, f'{hydra_path}/graphs/{domain}/train.pt')
|
| 67 |
-
torch.save(val_data, f'{hydra_path}/graphs/{domain}/val.pt')
|
| 68 |
-
torch.save(test_data, f'{hydra_path}/graphs/{domain}/test.pt')
|
| 69 |
-
|
| 70 |
|
| 71 |
train_data, val_data, test_data = [], [], []
|
| 72 |
|
|
@@ -99,30 +67,6 @@ def load_dataset_cc(dataname, batch_size, hydra_path, condition):
|
|
| 99 |
test_data = [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)] + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_data, val_indices)] + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_data, test_indices)]
|
| 100 |
|
| 101 |
|
| 102 |
-
elif dataname == 'all':
|
| 103 |
-
for i, domain in enumerate(domains):
|
| 104 |
-
train_d = torch.load(f'{hydra_path}/graphs/{domain}/train.pt')
|
| 105 |
-
val_d = torch.load(f'{hydra_path}/graphs/{domain}/val.pt')
|
| 106 |
-
test_d = torch.load(f'{hydra_path}/graphs/{domain}/test.pt')
|
| 107 |
-
|
| 108 |
-
train_indices = torch.load(f'{hydra_path}/graphs/{domain}/train_indices.pt')
|
| 109 |
-
val_indices = torch.load(f'{hydra_path}/graphs/{domain}/val_indices.pt')
|
| 110 |
-
test_indices = torch.load(f'{hydra_path}/graphs/{domain}/test_indices.pt')
|
| 111 |
-
|
| 112 |
-
# text_prompt = torch.load(f'{hydra_path}/graphs/{domain}/text_prompt_order.pt')
|
| 113 |
-
|
| 114 |
-
with open(f'{hydra_path}/graphs/{domain}/text_prompt_order.txt', 'r') as f:
|
| 115 |
-
text_prompt = f.readlines()
|
| 116 |
-
text_prompt = [x.strip() for x in text_prompt]
|
| 117 |
-
|
| 118 |
-
print(domain, text_prompt[0])
|
| 119 |
-
|
| 120 |
-
text_embs = model.encode(text_prompt)
|
| 121 |
-
|
| 122 |
-
train_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)])
|
| 123 |
-
val_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_d, val_indices)])
|
| 124 |
-
test_data.extend([arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_d, test_indices)])
|
| 125 |
-
print(i, domain, len(train_data), len(val_data), len(test_data))
|
| 126 |
|
| 127 |
print('Size of dataset', len(train_data), len(val_data), len(test_data))
|
| 128 |
|
|
|
|
| 34 |
model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 35 |
cond_embs = model.encode(condition)
|
| 36 |
|
| 37 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
train_data, val_data, test_data = [], [], []
|
| 40 |
|
|
|
|
| 67 |
test_data = [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(train_d, train_indices)] + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(val_data, val_indices)] + [arrange_data(d, text_embs[ind.item()], ind.item()) for d, ind in zip(test_data, test_indices)]
|
| 68 |
|
| 69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
print('Size of dataset', len(train_data), len(val_data), len(test_data))
|
| 72 |
|