Commit · e19e1b1
1 Parent(s): 2cf93bf
Upload 4 files
Browse files
- pretrain.py +0 -75
- requirements.txt +8 -0
pretrain.py
CHANGED
|
@@ -34,18 +34,11 @@ import pickle as pkl
|
|
| 34 |
from sophia import SophiaG
|
| 35 |
|
| 36 |
|
| 37 |
-
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
|
| 38 |
|
| 39 |
-
# # constants
|
| 40 |
-
|
| 41 |
-
# NUM_BATCHES = int(1e5)
|
| 42 |
-
# BATCH_SIZE = 4
|
| 43 |
GRADIENT_ACCUMULATE_EVERY = 4
|
| 44 |
LEARNING_RATE = 1e-4
|
| 45 |
VALIDATE_EVERY = 100
|
| 46 |
GENERATE_EVERY = 500
|
| 47 |
-
# GENERATE_LENGTH = 2048
|
| 48 |
-
# SEQ_LEN = 4096
|
| 49 |
|
| 50 |
|
| 51 |
parser = argparse.ArgumentParser()
|
|
@@ -65,9 +58,6 @@ parser.add_argument("--ckpt_dir", type=str, default='./ckpts/', help='Directory
|
|
| 65 |
parser.add_argument("--model_name", type=str, default='finetune', help='Finetuned model name.')
|
| 66 |
|
| 67 |
args = parser.parse_args()
|
| 68 |
-
# rank = int(os.environ["RANK"])
|
| 69 |
-
# local_rank = args.local_rank
|
| 70 |
-
# is_master = local_rank == 0
|
| 71 |
|
| 72 |
SEED = args.seed
|
| 73 |
EPOCHS = args.epoch
|
|
@@ -86,14 +76,6 @@ POS_EMBED_USING = args.pos_embed
|
|
| 86 |
model_name = args.model_name
|
| 87 |
ckpt_dir = args.ckpt_dir
|
| 88 |
|
| 89 |
-
# dist.init_process_group(backend='nccl')
|
| 90 |
-
# torch.cuda.set_device(local_rank)
|
| 91 |
-
# device = torch.device("cuda", local_rank)
|
| 92 |
-
# world_size = torch.distributed.get_world_size()
|
| 93 |
-
|
| 94 |
-
# seed_all(SEED + torch.distributed.get_rank())
|
| 95 |
-
|
| 96 |
-
|
| 97 |
|
| 98 |
# helpers
|
| 99 |
|
|
@@ -127,27 +109,7 @@ model = PerformerLM(
|
|
| 127 |
model = AutoregressiveWrapper(model)
|
| 128 |
model.cuda()
|
| 129 |
|
| 130 |
-
|
| 131 |
-
|
| 132 |
# prepare sc data
|
| 133 |
-
|
| 134 |
-
class SCDataset(Dataset):
|
| 135 |
-
def __init__(self, data, label):
|
| 136 |
-
super().__init__()
|
| 137 |
-
self.data = data
|
| 138 |
-
self.label = label
|
| 139 |
-
|
| 140 |
-
def __getitem__(self, index):
|
| 141 |
-
rand_start = random.randint(0, self.data.shape[0]-1)
|
| 142 |
-
full_seq = self.data[rand_start].toarray()[0]
|
| 143 |
-
full_seq[full_seq > (CLASS - 2)] = CLASS - 2
|
| 144 |
-
full_seq = torch.from_numpy(full_seq).long()
|
| 145 |
-
full_seq = torch.cat((full_seq, torch.tensor([0]))).to(device)
|
| 146 |
-
seq_label = self.label[rand_start]
|
| 147 |
-
return full_seq, seq_label
|
| 148 |
-
|
| 149 |
-
def __len__(self):
|
| 150 |
-
return self.data.shape[0]
|
| 151 |
|
| 152 |
class SCDatasetPretrain(Dataset):
|
| 153 |
def __init__(self, data, seq_len):
|
|
@@ -169,19 +131,8 @@ class SCDatasetPretrain(Dataset):
|
|
| 169 |
|
| 170 |
def __len__(self):
|
| 171 |
return self.data.shape[0]
|
| 172 |
-
|
| 173 |
|
| 174 |
data = sc.read_h5ad(args.data_path)
|
| 175 |
-
#data = data[:1000, :]
|
| 176 |
-
# label_dict, label = np.unique(np.array(data.obs['cell_type']), return_inverse=True) # Convert strings categorical to integrate categorical, and label_dict[label] can be restored
|
| 177 |
-
# #store the label dict and label for prediction
|
| 178 |
-
# with open('label_dict', 'wb') as fp:
|
| 179 |
-
# pkl.dump(label_dict, fp)
|
| 180 |
-
# with open('label', 'wb') as fp:
|
| 181 |
-
# pkl.dump(label, fp)
|
| 182 |
-
# class_num = np.unique(label, return_counts=True)[1].tolist()
|
| 183 |
-
# class_weight = torch.tensor([(1 - (x / sum(class_num))) ** 2 for x in class_num])
|
| 184 |
-
# label = torch.from_numpy(label)
|
| 185 |
data = data.X
|
| 186 |
|
| 187 |
acc = []
|
|
@@ -190,18 +141,6 @@ f1w = []
|
|
| 190 |
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
|
| 191 |
pred_list = pd.Series(['un'] * data.shape[0])
|
| 192 |
|
| 193 |
-
# sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=SEED)
|
| 194 |
-
# for index_train in sss.split(data):
|
| 195 |
-
# data_train = data[index_train]
|
| 196 |
-
# data_val = data[index_val]
|
| 197 |
-
# train_dataset = SCDatasetPretrain(data_train, SEQ_LEN)
|
| 198 |
-
# val_dataset = SCDatasetPretrain(data_val, SEQ_LEN)
|
| 199 |
-
|
| 200 |
-
# train_sampler = DistributedSampler(train_dataset)
|
| 201 |
-
# val_sampler = DistributedSampler(val_dataset)
|
| 202 |
-
# train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=train_sampler)
|
| 203 |
-
# val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, sampler=val_sampler)
|
| 204 |
-
|
| 205 |
index_train = int(data.shape[0]*0.8)
|
| 206 |
data_train = data[:index_train]
|
| 207 |
data_val = data[index_train:]
|
|
@@ -210,15 +149,11 @@ val_dataset = SCDatasetPretrain(data_val, SEQ_LEN)
|
|
| 210 |
|
| 211 |
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
|
| 212 |
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
|
| 213 |
-
# train_loader = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
|
| 214 |
-
# val_loader = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))
|
| 215 |
|
| 216 |
# optimizer
|
| 217 |
|
| 218 |
optim = SophiaG(model.parameters(), lr=2e-4,
|
| 219 |
betas=(0.965, 0.99), rho = 0.01, weight_decay=1e-1)
|
| 220 |
-
# optim = torch.optim.SGD(model.parameters(), lr=1e-8, momentum=0.9)
|
| 221 |
-
# optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
| 222 |
scaler = GradScaler()
|
| 223 |
|
| 224 |
# training
|
|
@@ -244,14 +179,6 @@ for i in tqdm(range(EPOCHS), mininterval=10., desc='training'):
|
|
| 244 |
scaler.update()
|
| 245 |
optim.zero_grad()
|
| 246 |
|
| 247 |
-
# if i % VALIDATE_EVERY == 0:
|
| 248 |
-
# model.eval()
|
| 249 |
-
# with torch.no_grad():
|
| 250 |
-
# #loss = model(next(val_loader), return_loss = True)
|
| 251 |
-
# for index, data_batch in enumerate(tqdm(val_loader)):
|
| 252 |
-
# loss = model(data_batch, return_loss = True)
|
| 253 |
-
# print(f'validation loss: {loss.item()}')
|
| 254 |
-
|
| 255 |
if i % GENERATE_EVERY == 0 and i != 0:
|
| 256 |
model.eval()
|
| 257 |
inp = random.choice(val_dataset)[:-1]
|
|
@@ -266,5 +193,3 @@ for i in tqdm(range(EPOCHS), mininterval=10., desc='training'):
|
|
| 266 |
print('save model')
|
| 267 |
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optim.state_dict()}
|
| 268 |
torch.save(checkpoint, os.path.join(ckpt_dir, 'model_gene_attn.pth'))
|
| 269 |
-
|
| 270 |
-
a=1
|
|
|
|
| 34 |
from sophia import SophiaG
|
| 35 |
|
| 36 |
|
|
|
|
| 37 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
GRADIENT_ACCUMULATE_EVERY = 4
|
| 39 |
LEARNING_RATE = 1e-4
|
| 40 |
VALIDATE_EVERY = 100
|
| 41 |
GENERATE_EVERY = 500
|
|
|
|
|
|
|
| 42 |
|
| 43 |
|
| 44 |
parser = argparse.ArgumentParser()
|
|
|
|
| 58 |
parser.add_argument("--model_name", type=str, default='finetune', help='Finetuned model name.')
|
| 59 |
|
| 60 |
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
SEED = args.seed
|
| 63 |
EPOCHS = args.epoch
|
|
|
|
| 76 |
model_name = args.model_name
|
| 77 |
ckpt_dir = args.ckpt_dir
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
# helpers
|
| 81 |
|
|
|
|
| 109 |
model = AutoregressiveWrapper(model)
|
| 110 |
model.cuda()
|
| 111 |
|
|
|
|
|
|
|
| 112 |
# prepare sc data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
class SCDatasetPretrain(Dataset):
|
| 115 |
def __init__(self, data, seq_len):
|
|
|
|
| 131 |
|
| 132 |
def __len__(self):
|
| 133 |
return self.data.shape[0]
|
|
|
|
| 134 |
|
| 135 |
data = sc.read_h5ad(args.data_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
data = data.X
|
| 137 |
|
| 138 |
acc = []
|
|
|
|
| 141 |
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=SEED)
|
| 142 |
pred_list = pd.Series(['un'] * data.shape[0])
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
index_train = int(data.shape[0]*0.8)
|
| 145 |
data_train = data[:index_train]
|
| 146 |
data_val = data[index_train:]
|
|
|
|
| 149 |
|
| 150 |
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
|
| 151 |
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# optimizer
|
| 154 |
|
| 155 |
optim = SophiaG(model.parameters(), lr=2e-4,
|
| 156 |
betas=(0.965, 0.99), rho = 0.01, weight_decay=1e-1)
|
|
|
|
|
|
|
| 157 |
scaler = GradScaler()
|
| 158 |
|
| 159 |
# training
|
|
|
|
| 179 |
scaler.update()
|
| 180 |
optim.zero_grad()
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
if i % GENERATE_EVERY == 0 and i != 0:
|
| 183 |
model.eval()
|
| 184 |
inp = random.choice(val_dataset)[:-1]
|
|
|
|
| 193 |
print('save model')
|
| 194 |
checkpoint = {'state_dict': model.state_dict(),'optimizer' :optim.state_dict()}
|
| 195 |
torch.save(checkpoint, os.path.join(ckpt_dir, 'model_gene_attn.pth'))
|
|
|
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==1.8.1
|
| 2 |
+
torchvision==0.9.1
|
| 3 |
+
transformers==4.6.1
|
| 4 |
+
scanpy==1.7.2
|
| 5 |
+
scikit-learn==0.24.2
|
| 6 |
+
scipy==1.5.4
|
| 7 |
+
numpy==1.19.2
|
| 8 |
+
pandas==1.1.5
|