|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import subprocess
|
# Process environment, set here before the third-party imports that follow in
# this script.
# Indices of the GPUs exposed to this process through CUDA_VISIBLE_DEVICES.
GPU_NUMBER = [0, 1, 2, 3]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, GPU_NUMBER))
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
# NOTE: placeholder paths; replace with the library directories of the local install.
os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
|
|
|
|
|
| import pyarrow
|
| import ray
|
| from ray import tune
|
| from ray.tune import ExperimentAnalysis
|
| from ray.tune.suggest.hyperopt import HyperOptSearch
|
# Shut down any Ray session already attached to this interpreter, then start a
# new one whose workers use the "base" conda environment and the library path below.
ray.shutdown()

# NOTE: placeholder paths; replace with the library directories of the local install.
runtime_env = {
    "conda": "base",
    "env_vars": {
        "LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib",
    },
}
ray.init(runtime_env=runtime_env)
|
|
|
def initialize_ray_with_check(ip_address):
    """
    Initialize Ray with a specified IP address and check its status and accessibility.

    Prints the node list returned by ``ray.nodes()`` and the dashboard URL on
    success. Any exception raised along the way is caught, printed, and
    reported through the return value instead of propagating.

    Args:
    - ip_address (str): The IP address (with port) to initialize Ray.

    Returns:
    - bool: True if initialization was successful and dashboard is accessible, False otherwise.
    """
    try:
        # NOTE(review): ray.init raises when Ray is already initialised in this
        # process, and this script calls ray.init earlier at module level; that
        # error ends up in the except branch below. Confirm this is intended.
        ray.init(address=ip_address)
        print(ray.nodes())

        # NOTE(review): get_webui_url may not exist in every Ray release;
        # verify against the installed version.
        services = ray.get_webui_url()
        if not services:
            raise RuntimeError("Ray dashboard is not accessible.")
        print(f"Ray dashboard is accessible at: {services}")
        return True
    except Exception as e:
        # Deliberately broad: this is a best-effort check that reports failure
        # to the caller as False rather than aborting the script.
        print(f"Error initializing Ray: {e}")
        return False
|
|
|
|
|
# NOTE: placeholder; replace with the "host:port" address of the Ray cluster.
ip = 'your_ip:xxxx'

# Report the outcome of the connection check; the script continues either way.
if initialize_ray_with_check(ip):
    print("Ray initialized successfully.")
else:
    print("Error during Ray initialization.")
|
|
|
| import datetime
|
| import numpy as np
|
| import pandas as pd
|
| import random
|
| import seaborn as sns; sns.set()
|
| from collections import Counter
|
| from datasets import load_from_disk
|
| from scipy.stats import ranksums
|
| from sklearn.metrics import accuracy_score
|
| from transformers import BertForSequenceClassification
|
| from transformers import Trainer
|
| from transformers.training_args import TrainingArguments
|
|
|
| from geneformer import DataCollatorForCellClassification
|
|
|
|
|
# Number of worker processes passed to the dataset filter/map calls in this script.
num_proc = 30

# Dataset previously saved to disk; the path is a placeholder to be replaced.
train_dataset = load_from_disk("/path/to/disease_train_data.dataset")
|
|
|
|
|
def if_cell_type(example):
    """Return True when the example's "cell_type" value starts with "Cardiomyocyte".

    Args:
        example: Mapping with a string "cell_type" entry.

    Returns:
        bool: Result of the case-sensitive prefix check.
    """
    return example["cell_type"].startswith("Cardiomyocyte")
|
|
|
# Keep only the examples accepted by if_cell_type.
trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)

# Class names in label-id order: each name's position in the list is its integer id.
target_names = ["healthy", "disease1", "disease2"]
target_name_id_dict = {name: idx for idx, name in enumerate(target_names)}

# Rename the "disease" column to "label".
trainset_v3 = trainset_v2.rename_column("disease", "label")
|
|
|
|
|
def classes_to_ids(example):
    """Replace the example's "label" class name with its integer id.

    The id is looked up in the module-level ``target_name_id_dict``; a label
    missing from that mapping raises ``KeyError``.

    Args:
        example: Mapping whose "label" entry is a key of ``target_name_id_dict``.

    Returns:
        The same ``example`` object, modified in place.
    """
    example["label"] = target_name_id_dict[example["label"]]
    return example
|
|
|
# Convert every class-name label to its integer id.
trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)

# Split by individual rather than by example: 70% of the individuals go to
# training and the remainder is halved into validation and test.
indiv_set = set(trainset_v4["individual"])
# random.sample requires a sequence (passing a set is rejected from Python
# 3.11), and set iteration order is not stable between interpreter runs, so the
# seeded draws are made from a sorted list to keep the split reproducible.
indiv_pool = sorted(indiv_set, key=str)
random.seed(42)
train_indiv = random.sample(indiv_pool, round(0.7 * len(indiv_pool)))
eval_indiv = [indiv for indiv in indiv_pool if indiv not in train_indiv]
valid_indiv = random.sample(eval_indiv, round(0.5 * len(eval_indiv)))
test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
|
|
|
def if_train(example):
    """Return True when the example's "individual" is one of ``train_indiv``."""
    return example["individual"] in train_indiv


# Training subset, shuffled with a fixed seed.
classifier_trainset = trainset_v4.filter(if_train, num_proc=num_proc).shuffle(seed=42)
|
|
|
def if_valid(example):
    """Return True when the example's "individual" is one of ``valid_indiv``."""
    return example["individual"] in valid_indiv


# Validation subset, shuffled with a fixed seed.
classifier_validset = trainset_v4.filter(if_valid, num_proc=num_proc).shuffle(seed=42)
|
|
|
|
|
# Output directory named after today's date (YYMMDD); the base path is a placeholder.
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d")
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"

# Refuse to continue if a saved model file is already present in that directory.
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
if os.path.isfile(saved_model_test):
    raise FileExistsError("Model already saved to this directory.")

# Create the directory (and any missing parents) without going through a shell;
# an already existing directory is not an error.
os.makedirs(output_dir, exist_ok=True)
|
|
|
|
|
|
|
# Number of leading encoder layers frozen by model_init (None freezes nothing).
freeze_layers = 2

# Per-device batch size used for both training and evaluation.
geneformer_batch_size = 12

# Number of training epochs.
epochs = 1

# Log and evaluate every tenth of the batches in one pass over the training
# set. Clamped to at least 1 so a small training set cannot yield a step
# interval of 0.
logging_steps = max(1, round(len(classifier_trainset) / geneformer_batch_size / 10))
|
|
|
|
|
def model_init():
    """Load the pretrained model as a sequence classifier and place it on "cuda:0".

    The classifier has one output per entry of ``target_names``. When the
    module-level ``freeze_layers`` is not None, the parameters of the first
    ``freeze_layers`` encoder layers are excluded from gradient updates.

    Returns:
        The ``BertForSequenceClassification`` model, moved to "cuda:0".
    """
    model = BertForSequenceClassification.from_pretrained(
        "/path/to/pretrained_model/",
        num_labels=len(target_names),
        output_attentions=False,
        output_hidden_states=False,
    )
    if freeze_layers is not None:
        for module in model.bert.encoder.layer[:freeze_layers]:
            for param in module.parameters():
                param.requires_grad = False

    model = model.to("cuda:0")
    return model
|
|
|
|
|
|
|
def compute_metrics(pred):
    """Compute accuracy from a prediction object.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (an array whose last axis is reduced with ``argmax`` to obtain the
            predicted class).

    Returns:
        dict: ``{"accuracy": <value returned by accuracy_score>}``.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
    }
|
|
|
|
|
# Keyword arguments for TrainingArguments. Evaluation and logging share the
# same step interval, and the batch size is the same for training and evaluation.
training_args = {
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "steps",
    "eval_steps": logging_steps,
    "logging_steps": logging_steps,
    "group_by_length": True,
    # NOTE(review): presumably the datasets carry a "length" column; confirm.
    "length_column_name": "length",
    "disable_tqdm": True,
    "skip_memory_metrics": True,
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)
|
|
|
|
|
# The Trainer receives model_init (a factory) rather than a model instance.
trainer = Trainer(
    model_init=model_init,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=classifier_trainset,
    eval_dataset=classifier_validset,
    compute_metrics=compute_metrics,
)
|
|
|
|
|
# Search space handed to Ray Tune. Epoch count and batch size are held fixed by
# offering a single choice each; the remaining entries are sampled per trial.
ray_config = {
    "num_train_epochs": tune.choice([epochs]),
    "learning_rate": tune.loguniform(1e-6, 1e-3),
    "weight_decay": tune.uniform(0.0, 0.3),
    "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
    "warmup_steps": tune.uniform(100, 2000),
    "seed": tune.uniform(0, 100),
    "per_device_train_batch_size": tune.choice([geneformer_batch_size]),
}

# Search algorithm configured to maximise "eval_accuracy".
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
|
|
|
|
|
# Run 100 trials on the Ray backend, each allotted 8 CPUs and 1 GPU. The CLI
# reporter is configured with a maximum report frequency of 600 and sorts
# trials by "eval_accuracy".
trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    resources_per_trial={"cpu": 8, "gpu": 1},
    hp_space=lambda _: ray_config,
    search_alg=hyperopt_search,
    n_trials=100,
    progress_reporter=tune.CLIReporter(
        max_report_frequency=600,
        sort_by_metric=True,
        max_progress_rows=100,
        mode="max",
        metric="eval_accuracy",
        metric_columns=["loss", "eval_loss", "eval_accuracy"],
    ),
)