Sid01123 commited on
Commit ·
9735e01
1
Parent(s): 54aa369
all the code
Browse files- README.md +53 -0
- agent/__pycache__/agent.cpython-313.pyc +0 -0
- agent/__pycache__/prompt.cpython-313.pyc +0 -0
- agent/__pycache__/utils.cpython-313.pyc +0 -0
- agent/agent.py +676 -0
- agent/prompt.py +64 -0
- agent/utils.py +75 -0
- data/README.md +7 -0
- data/jump-dataset.ipynb +0 -0
- data/jump-similarity.ipynb +0 -0
- model.py +88 -0
- run_app.py +296 -0
- tahoe_model/apply_linear_model.py +93 -0
- tahoe_model/compute_tahoe_deltas.py +26 -0
- tahoe_model/merge_tahoe_vision.py +66 -0
README.md
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
datasets:
|
| 4 |
+
- tahoebio/Tahoe-100M
|
| 5 |
+
tags:
|
| 6 |
+
- tahoe-deepdive
|
| 7 |
+
- hackathon
|
| 8 |
+
- tahoe-100M
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
<div align="center">
|
| 12 |
+
<img src="img/SigSpace.png" alt="SigSpace Logo" width="400"/>
|
| 13 |
+
</div>
|
| 14 |
+
|
| 15 |
+
# SigSpace: An AI Agent for the Tahoe-100M dataset
|
| 16 |
+
This is a submission for the **Tahoe-DeepDive Hackathon 2025**.
|
| 17 |
+
|
| 18 |
+
# Team Name
|
| 19 |
+
SigSpace
|
| 20 |
+
|
| 21 |
+
## Members
|
| 22 |
+
- Ishita Mangla
|
| 23 |
+
- Kuan Pang
|
| 24 |
+
- Giovanni Palla
|
| 25 |
+
- Yanay Rosen
|
| 26 |
+
- Sid Sanghi
|
| 27 |
+
- Yasha Ektefaie
|
| 28 |
+
- Rohit Khurana
|
| 29 |
+
|
| 30 |
+
# Project
|
| 31 |
+
## SigSpace: An AI Agent for the Tahoe-100M dataset
|
| 32 |
+
|
| 33 |
+
## Overview
|
| 34 |
+
We have developed an AI agent that accesses the Tahoe-100M dataset along with publicly available and novel datasets. This agent works to refine and expand the mechanisms of action (MOA) and drug signatures of the perturbations within the Tahoe-100M dataset.
|
| 35 |
+
|
| 36 |
+
## Motivation
|
| 37 |
+
Drug discovery in the age of Large Language Models (LLMs) can be enhanced through agentic workflows that parse diverse sources of unstructured information to synthesize and connect hypotheses across different fields and modalities. However, these models are primarily trained on text data and lack the capacity to effectively interrogate rich biological databases with complex, biologically-motivated queries. In this work, we provide a proof of concept demonstrating how the Tahoe-100M dataset can be integrated with publicly available relevant datasets to expand the hypothesis space for mechanisms of action and drug responses in the perturbations tested in the Tahoe-100M dataset.
|
| 38 |
+
|
| 39 |
+
## Methods
|
| 40 |
+
We have curated new datasets that enhance the description of drugs and cell-lines present in the Tahoe-100M dataset.
|
| 41 |
+
|
| 42 |
+
Specifically:
|
| 43 |
+
- TAHOE-100M: vision scores and metadata.
|
| 44 |
+
- PRISM: We use PRISM drug sensitivity data, which reports the concentration of a compound needed to inhibit 50% of cancer cell viability. Measurements are based on pooled screening of barcoded cell lines and provide a high-throughput assessment of drug response across a large panel of cancer models.
|
| 45 |
+
- NCI60: We use NCI-60 LC50 data, which reports the concentration of a drug that kills 50% of the cells present at the time of drug addition. It is measured across a panel of 60 human cancer cell lines using standardized multi-dose assays.
|
| 46 |
+
- JUMP: We use the JUMP dataset, which captures morphological profiles of cells in response to chemical and genetic perturbations. High-content imaging and automated feature extraction are used to quantify cellular changes, enabling large-scale profiling of perturbation effects across diverse biological contexts.
|
| 47 |
+
- UCE-CXG-EMBEDDING: natural perturbation search using AI virtual cell.
|
| 48 |
+
|
| 49 |
+
## Results
|
| 50 |
+
|
| 51 |
+
We have developed a Gradio application that accesses these databases and performs complex queries, enhancing and grounding the reasoning in real biological measurements.
|
| 52 |
+
|
| 53 |
+
## Discussion
|
agent/__pycache__/agent.cpython-313.pyc
ADDED
|
Binary file (23.8 kB). View file
|
|
|
agent/__pycache__/prompt.cpython-313.pyc
ADDED
|
Binary file (3.7 kB). View file
|
|
|
agent/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (4.22 kB). View file
|
|
|
agent/agent.py
ADDED
|
@@ -0,0 +1,676 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from agent.utils import *
|
| 2 |
+
from agent.prompt import *
|
| 3 |
+
import anndata
|
| 4 |
+
import gradio as gr
|
| 5 |
+
from gradio import ChatMessage
|
| 6 |
+
import re
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import pathlib
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
class SigSpace(Basic_Agent):
|
| 12 |
+
def __init__(self, config_path:str):
|
| 13 |
+
super().__init__(config_path)
|
| 14 |
+
self.conversation = []
|
| 15 |
+
self.system_prompt = Agent_Prompt
|
| 16 |
+
self.conversation = []
|
| 17 |
+
self.conversation.append({"role": "system", "content": self.system_prompt})
|
| 18 |
+
|
| 19 |
+
# initialize data for jump
|
| 20 |
+
path = pathlib.Path("/home/ubuntu/giovanni/code/Tahoe_Hackathon/datasets")
|
| 21 |
+
# jump_path = pathlib.Path("/home/ubuntu/giovanni/data")
|
| 22 |
+
self.jump_tahoe_drug_metadata = pd.read_csv(path/"drug_metadata_inchikey.csv")
|
| 23 |
+
self.jump_similarity_score = pd.read_csv(path/"compound_genetic_perturbation_cosine_similarity_inchikey.csv")
|
| 24 |
+
|
| 25 |
+
# Load PRISM IC50 matrix
|
| 26 |
+
# prism_data_path = pathlib.Path("/home/ubuntu/sid/Hackathon_Tahoe/data")
|
| 27 |
+
self.ic50 = pd.read_csv(path / "Tahoe_PRISM_cell_by_drug_ic50_matrix_named.csv", index_col=0)
|
| 28 |
+
self.ic50.columns = self.ic50.columns.str.lower()
|
| 29 |
+
|
| 30 |
+
# nci60_path = pathlib.Path("/home/ubuntu/ishita/tahoe/")
|
| 31 |
+
self.lc50 = pd.read_csv(path / "filtered_results.csv")
|
| 32 |
+
# Filter out rows where CELL is nan
|
| 33 |
+
self.lc50 = self.lc50[self.lc50['CELL'].notna()]
|
| 34 |
+
|
| 35 |
+
# Load full Tahoe metadata
|
| 36 |
+
# tahoe_path = pathlib.Path("/home/ubuntu/rohit/data")
|
| 37 |
+
self.tahoe_cell_meta = pd.read_csv(path / "cell_line_metadata.csv")
|
| 38 |
+
self.tahoe_drug_meta = pd.read_csv(path / "drug_metadata.csv")
|
| 39 |
+
self.tahoe_vision_scores = anndata.read_h5ad(path / "tahoe_vision_scores.h5ad")
|
| 40 |
+
|
| 41 |
+
# Load PRISM subset of Tahoe metadata
|
| 42 |
+
self.prism_tahoe_cell_meta = pd.read_csv(path / "Tahoe_PRISM_matched_cell_metadata_final.csv")
|
| 43 |
+
self.prism_tahoe_drug_meta = pd.read_csv(path / "Tahoe_PRISM_matched_drug_metadata_final.csv")
|
| 44 |
+
|
| 45 |
+
# Build cell line common name to depmap_id map (strip whitespace and case)
|
| 46 |
+
self.cell_name_to_depmap = {
|
| 47 |
+
row["cell_name"].strip(): row["Cell_ID_DepMap"]
|
| 48 |
+
for _, row in self.prism_tahoe_cell_meta.iterrows()
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
self.cell_name_to_depmap_lc50 = {
|
| 52 |
+
row["clean"].strip(): row["cell_line_name"]
|
| 53 |
+
for _, row in self.lc50.iterrows()
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
self.tahoe_similarity_score = pd.read_csv(path / "in_tahoe_search_result_df.csv")
|
| 57 |
+
self.tahoe_cxg_similarity_score = pd.read_csv(path / "cxg_search_result_df.csv")
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def initialize_conversation(self, message, conversation=None, history=None):
|
| 61 |
+
if conversation is None:
|
| 62 |
+
conversation = []
|
| 63 |
+
|
| 64 |
+
conversation.append({"role": "system", "content" : Agent_Prompt})
|
| 65 |
+
|
| 66 |
+
if history is not None:
|
| 67 |
+
if len(history) == 0:
|
| 68 |
+
conversation = []
|
| 69 |
+
print("clear conversation successfully")
|
| 70 |
+
else:
|
| 71 |
+
for i in range(len(history)):
|
| 72 |
+
if history[i]['role'] == 'user':
|
| 73 |
+
if i-1 >= 0 and history[i-1]['role'] == 'assistant':
|
| 74 |
+
conversation.append(
|
| 75 |
+
{"role": "assistant", "content": history[i-1]['content']})
|
| 76 |
+
conversation.append(
|
| 77 |
+
{"role": "user", "content": history[i]['content']})
|
| 78 |
+
if i == len(history)-1 and history[i]['role'] == 'assistant':
|
| 79 |
+
conversation.append(
|
| 80 |
+
{"role": "assistant", "content": history[i]['content']})
|
| 81 |
+
|
| 82 |
+
conversation.append({"role": "user", "content": message})
|
| 83 |
+
|
| 84 |
+
return conversation
|
| 85 |
+
|
| 86 |
+
def get_similar_disease(self, disease_name, k_value):
|
| 87 |
+
if disease_name != "Alzheimer's":
|
| 88 |
+
return "FAIL"
|
| 89 |
+
return 'Parkinsons Disease'
|
| 90 |
+
|
| 91 |
+
def get_validated_target_jump(self, drug_name):
|
| 92 |
+
print(drug_name)
|
| 93 |
+
try:
|
| 94 |
+
inchikey = self.jump_tahoe_drug_metadata[self.jump_tahoe_drug_metadata.drug.isin([drug_name])]["InChIKey"].values[0]
|
| 95 |
+
similarity_scores = self.jump_similarity_score[self.jump_similarity_score.InChIKey.isin([inchikey])]
|
| 96 |
+
|
| 97 |
+
# Count ORF entries with cosine_similarity > 0.2 and < -0.2
|
| 98 |
+
orf_positive = similarity_scores[(similarity_scores.Genetic_Perturbation == 'ORF') & (similarity_scores.cosine_sim > 0.2)].shape[0]
|
| 99 |
+
orf_negative = similarity_scores[(similarity_scores.Genetic_Perturbation == 'ORF') & (similarity_scores.cosine_sim < -0.2)].shape[0]
|
| 100 |
+
|
| 101 |
+
# Count CRISPR entries with cosine_similarity > 0.2 and < -0.2
|
| 102 |
+
crispr_positive = similarity_scores[(similarity_scores.Genetic_Perturbation == 'CRISPR') & (similarity_scores.cosine_sim > 0.2)].shape[0]
|
| 103 |
+
crispr_negative = similarity_scores[(similarity_scores.Genetic_Perturbation == 'CRISPR') & (similarity_scores.cosine_sim < -0.2)].shape[0]
|
| 104 |
+
|
| 105 |
+
orf_targets = f"ORF: {orf_positive} positive correlations (>0.2), {orf_negative} negative correlations (<-0.2)"
|
| 106 |
+
crispr_targets = f"CRISPR: {crispr_positive} positive correlations (>0.2), {crispr_negative} negative correlations (<-0.2)"
|
| 107 |
+
|
| 108 |
+
orf_crispr_targets = orf_targets + " " +crispr_targets
|
| 109 |
+
|
| 110 |
+
known_targets_from_jump = self.jump_tahoe_drug_metadata[self.jump_tahoe_drug_metadata.drug.isin([drug_name])]["target_list"].values[0]
|
| 111 |
+
known_targets_output = f"The known targets from the JUMP dataset are: {', '.join(known_targets_from_jump.split('|'))}"
|
| 112 |
+
except Exception as e:
|
| 113 |
+
print(e)
|
| 114 |
+
return "For the drug {drug_name}, we were not able to find the target in the JUMP dataset."
|
| 115 |
+
|
| 116 |
+
orf_crispr_targets = \
|
| 117 |
+
f"""
|
| 118 |
+
Preturbation description:
|
| 119 |
+
|
| 120 |
+
ORF: The ORF perturbation consists of an overexpression of the target gene.
|
| 121 |
+
CRISPR: The CRISPR perturbation consists of a knockout of the target gene.
|
| 122 |
+
|
| 123 |
+
Considering the drug "{drug_name}", we expect positive correlations with shared CRISPR targets,
|
| 124 |
+
and negative correlations with shared ORF targets.
|
| 125 |
+
|
| 126 |
+
But, the measured correlations are:
|
| 127 |
+
|
| 128 |
+
{orf_crispr_targets}
|
| 129 |
+
|
| 130 |
+
Furthermore, the JUMP dataset has the following known targets for the drug "{drug_name}":
|
| 131 |
+
|
| 132 |
+
{known_targets_output}
|
| 133 |
+
"""
|
| 134 |
+
return orf_crispr_targets
|
| 135 |
+
|
| 136 |
+
def get_similar_drug_effect_in_tahoe(self, cell_line_name: str, drug_name: str):
|
| 137 |
+
"""
|
| 138 |
+
Get similar effect drugs in tahoe based on the drug name and cell line name.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
cell_line_name (str): The name of the cell line.
|
| 142 |
+
drug_name (str): The name of the drug.
|
| 143 |
+
"""
|
| 144 |
+
cell_line_names = self.tahoe_similarity_score["source_cell_line"].unique().tolist()
|
| 145 |
+
drug_names = self.tahoe_similarity_score["source_drug_name"].unique().tolist()
|
| 146 |
+
if cell_line_name not in cell_line_names:
|
| 147 |
+
return "FAIL: Cell line name not found in the dataset. A example: CVCL_0218"
|
| 148 |
+
if drug_name not in drug_names:
|
| 149 |
+
return "FAIL: Drug name not found in the dataset. A example: Daptomycin"
|
| 150 |
+
hits = self.tahoe_similarity_score[
|
| 151 |
+
(self.tahoe_similarity_score["source_cell_line"] == cell_line_name) &
|
| 152 |
+
(self.tahoe_similarity_score["source_drug_name"] == drug_name)
|
| 153 |
+
]
|
| 154 |
+
# sort by distance
|
| 155 |
+
hits = hits.sort_values(by="distance", ascending=True).reset_index(drop=True)
|
| 156 |
+
hits = hits.head(10)
|
| 157 |
+
# keep target_drug_name and target_cell_line
|
| 158 |
+
hits = hits[["target_drug_name", "target_cell_line",]]
|
| 159 |
+
outputs = f"""
|
| 160 |
+
The following drugs have similar effects to the drug you provided:
|
| 161 |
+
hits:
|
| 162 |
+
{hits}
|
| 163 |
+
"""
|
| 164 |
+
return outputs
|
| 165 |
+
|
| 166 |
+
def get_similar_drug_effects_in_cxg(self, cell_line_name: str, drug_name: str):
|
| 167 |
+
"""
|
| 168 |
+
Get similar effect diseases in cxg based on the drug name and cell line name.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
cell_line_name (str): The name of the cell line.
|
| 172 |
+
drug_name (str): The name of the drug.
|
| 173 |
+
"""
|
| 174 |
+
cell_line_names = self.tahoe_cxg_similarity_score["cell_line"].unique().tolist()
|
| 175 |
+
drug_names = self.tahoe_cxg_similarity_score["perturbation_drug_name"].unique().tolist()
|
| 176 |
+
if cell_line_name not in cell_line_names:
|
| 177 |
+
return "FAIL: Cell line name not found in the dataset. A valid example: CVCL_0218"
|
| 178 |
+
if drug_name not in drug_names:
|
| 179 |
+
return "FAIL: Drug name not found in the dataset. A valid example:: Daptomycin"
|
| 180 |
+
hits = self.tahoe_cxg_similarity_score[
|
| 181 |
+
(self.tahoe_cxg_similarity_score["cell_line"] == cell_line_name) &
|
| 182 |
+
(self.tahoe_cxg_similarity_score["perturbation_drug_name"] == drug_name)
|
| 183 |
+
]
|
| 184 |
+
hits = hits.sort_values(by="distance", ascending=True).reset_index(drop=True)
|
| 185 |
+
hits = hits.head(10)
|
| 186 |
+
# keeps cell_type tissue_type and disease
|
| 187 |
+
hits = hits[["cell_type", "tissue_type", "disease"]]
|
| 188 |
+
outputs = f"""
|
| 189 |
+
The following diseases have similar effects to the drug you provided:
|
| 190 |
+
hits:
|
| 191 |
+
{hits}
|
| 192 |
+
"""
|
| 193 |
+
return outputs
|
| 194 |
+
|
| 195 |
+
def get_ic50_prism(self, drug_name: str, cell_line_name: str):
|
| 196 |
+
drug_name_lower = drug_name.strip().lower()
|
| 197 |
+
cell_line_key = cell_line_name.strip()
|
| 198 |
+
|
| 199 |
+
if cell_line_key not in self.cell_name_to_depmap:
|
| 200 |
+
print(f"Cell line name '{cell_line_key}' not found for PRISM data")
|
| 201 |
+
return f"FAIL: Cell line name '{cell_line_key}' not found for PRISM data"
|
| 202 |
+
|
| 203 |
+
depmap_id = self.cell_name_to_depmap[cell_line_key]
|
| 204 |
+
|
| 205 |
+
if drug_name_lower not in self.ic50.columns:
|
| 206 |
+
print(f"Drug name '{drug_name}' not found in IC50 matrix columns.")
|
| 207 |
+
return f"FAIL: Drug name '{drug_name}' not found in IC50 matrix columns."
|
| 208 |
+
|
| 209 |
+
try:
|
| 210 |
+
ic50_val = self.ic50.loc[depmap_id, drug_name_lower]
|
| 211 |
+
if pd.isna(ic50_val):
|
| 212 |
+
print(f"FAIL: IC50 value is missing for '{drug_name}' in cell line '{cell_line_name}' (DepMap ID: {depmap_id}).")
|
| 213 |
+
return f"FAIL: IC50 value is missing for '{drug_name}' in cell line '{cell_line_name}' (DepMap ID: {depmap_id})."
|
| 214 |
+
|
| 215 |
+
return (
|
| 216 |
+
f"The IC50 value of {ic50_val:.4f} corresponds to the log10-transformed micromolar concentration "
|
| 217 |
+
f"at which {drug_name} inhibits 50% of viability in the {cell_line_name} cell line "
|
| 218 |
+
f"(DepMap ID: {depmap_id}).\n\n"
|
| 219 |
+
"This value comes from the PRISM Repurposing Secondary Screen, which exposes pooled barcoded cell lines "
|
| 220 |
+
"to drug treatment for 5 days and infers viability from barcode abundance using sequencing.\n\n"
|
| 221 |
+
"The secondary screen includes higher-confidence compound–cell line pairs with improved replicability "
|
| 222 |
+
"compared to the primary screen.\n\n"
|
| 223 |
+
"Lower IC50 values indicate greater sensitivity of the cell line to the drug."
|
| 224 |
+
)
|
| 225 |
+
except KeyError as e:
|
| 226 |
+
print(f"Combination not found: {e}")
|
| 227 |
+
return None
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def clean_cell_line_name(self, name):
|
| 231 |
+
"""
|
| 232 |
+
Standardize cell line names for comparison by:
|
| 233 |
+
1. Converting to string (handles any non-string values)
|
| 234 |
+
2. Converting to uppercase
|
| 235 |
+
3. Removing all non-alphanumeric characters
|
| 236 |
+
|
| 237 |
+
Args:
|
| 238 |
+
name: Cell line name (string or other type)
|
| 239 |
+
|
| 240 |
+
Returns:
|
| 241 |
+
Cleaned string with only uppercase letters and numbers
|
| 242 |
+
"""
|
| 243 |
+
return re.sub(r"[^A-Z0-9]", "", str(name).upper())
|
| 244 |
+
|
| 245 |
+
def get_lc50_nci60(self, drug_name: str, cell_line_name: str):
|
| 246 |
+
cell_line_name = cell_line_name.upper()
|
| 247 |
+
cell_line_key = self.clean_cell_line_name(cell_line_name)
|
| 248 |
+
|
| 249 |
+
if cell_line_key not in self.cell_name_to_depmap_lc50:
|
| 250 |
+
print(f"Cell line name '{cell_line_key}' not found for NCI60 data")
|
| 251 |
+
return None
|
| 252 |
+
depmap_id = self.cell_name_to_depmap_lc50[cell_line_key]
|
| 253 |
+
print ("Depmap_id", depmap_id)
|
| 254 |
+
|
| 255 |
+
# Find the drug in NCI60 dataset
|
| 256 |
+
# Since drugs are in uppercase in the list, convert search term to uppercase
|
| 257 |
+
drug_name_upper = drug_name.strip().upper()
|
| 258 |
+
|
| 259 |
+
# Filter rows where the drug name is in the drug column
|
| 260 |
+
# This assumes drugs in each row are comma-separated or in a format that can be searched
|
| 261 |
+
matching_row = self.lc50[self.lc50['drug'].str.contains(drug_name_upper, na=False)]
|
| 262 |
+
print ("Matching row", matching_row)
|
| 263 |
+
if matching_row.empty:
|
| 264 |
+
print(f"Drug name '{drug_name}' not found in NCI60 dataset.")
|
| 265 |
+
return None
|
| 266 |
+
|
| 267 |
+
if matching_row.empty:
|
| 268 |
+
raise ValueError(f"Multiple matches found for drug '{drug_name}' in NCI60 dataset.")
|
| 269 |
+
|
| 270 |
+
print ("Matching row", matching_row)
|
| 271 |
+
# Get the LC50 value from the matching row
|
| 272 |
+
lc50_val = matching_row.iloc[0]['NLOGLC50']
|
| 273 |
+
lconc_val = matching_row.iloc[0]['LCONC']
|
| 274 |
+
|
| 275 |
+
if pd.isna(lc50_val):
|
| 276 |
+
return "LC50 value is missing for '{drug_name}' in cell line '{cell_line_name}' (depmap_id: {depmap_id})."
|
| 277 |
+
|
| 278 |
+
lc50_output = \
|
| 279 |
+
f"""
|
| 280 |
+
The LC50 value of {lc50_val} represents -log10(LC50), the negative base-10 logarithm of the molar concentration that inhibits 50% of cell growth.
|
| 281 |
+
|
| 282 |
+
Higher LC50 values therefore indicate greater drug potency.
|
| 283 |
+
|
| 284 |
+
The LCONC value of {lconc_val} denotes the maximum log10 molar concentration tested in the dilution series—for example, LCONC = -4 corresponds to 10^-4 M.
|
| 285 |
+
|
| 286 |
+
Both metrics come from the NCI-60 drug screen, which applies a standardized 48-hour exposure assay across all compound–cell-line pairs."
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
return lc50_output
|
| 290 |
+
|
| 291 |
+
def load_gene_sets_file(self, file_path):
|
| 292 |
+
"""
|
| 293 |
+
Load gene sets from a tab-delimited file where the first column is the gene set name
|
| 294 |
+
and the remaining columns are gene symbols.
|
| 295 |
+
|
| 296 |
+
Parameters:
|
| 297 |
+
-----------
|
| 298 |
+
file_path : str
|
| 299 |
+
Path to the gene sets file
|
| 300 |
+
|
| 301 |
+
Returns:
|
| 302 |
+
--------
|
| 303 |
+
dict
|
| 304 |
+
Dictionary mapping gene set names to lists of genes
|
| 305 |
+
"""
|
| 306 |
+
gene_sets = {}
|
| 307 |
+
with open(file_path, 'r') as file:
|
| 308 |
+
for line in file:
|
| 309 |
+
parts = line.strip().split('\t')
|
| 310 |
+
if parts:
|
| 311 |
+
set_name = parts[0]
|
| 312 |
+
genes = [gene for gene in parts[1:] if gene] # Filter out empty strings
|
| 313 |
+
gene_sets[set_name] = genes
|
| 314 |
+
return gene_sets
|
| 315 |
+
|
| 316 |
+
def get_genes_for_set(self, set_name):
|
| 317 |
+
"""
|
| 318 |
+
Get the list of genes for a specific gene set.
|
| 319 |
+
|
| 320 |
+
Parameters:
|
| 321 |
+
-----------
|
| 322 |
+
set_name : str
|
| 323 |
+
Name of the gene set to query
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
--------
|
| 327 |
+
list
|
| 328 |
+
List of genes in the gene set, or empty list if set not found
|
| 329 |
+
"""
|
| 330 |
+
if not hasattr(self, 'gene_sets'):
|
| 331 |
+
# Load the gene sets file if it hasn't been loaded yet
|
| 332 |
+
self.gene_sets = self.load_gene_sets_file('/home/ubuntu/ishita/msigdb_all_sigs_human_symbols.txt')
|
| 333 |
+
|
| 334 |
+
return self.gene_sets.get(set_name, [])
|
| 335 |
+
|
| 336 |
+
def rank_vision_scores(self, drug_name: str, cell_line_name: str, k_value: int):
|
| 337 |
+
self.tahoe_vision_scores.X = (self.tahoe_vision_scores.X - np.mean(self.tahoe_vision_scores.X, axis = 0)) / np.std(self.tahoe_vision_scores.X, axis = 0)
|
| 338 |
+
|
| 339 |
+
# subset to the drug / cell line at the highest tested concentration
|
| 340 |
+
filt = (
|
| 341 |
+
(self.tahoe_vision_scores.obs["Cell_Name_Vevo"] == cell_line_name)
|
| 342 |
+
& (self.tahoe_vision_scores.obs["drug"] == drug_name)
|
| 343 |
+
)
|
| 344 |
+
filtered_scores = self.tahoe_vision_scores[filt]
|
| 345 |
+
if filtered_scores.n_obs == 0:
|
| 346 |
+
return "VISION scores not found for this drug–cell-line combination."
|
| 347 |
+
|
| 348 |
+
filtered_scores = filtered_scores[
|
| 349 |
+
filtered_scores.obs["concentration"] == filtered_scores.obs["concentration"].max()
|
| 350 |
+
]
|
| 351 |
+
|
| 352 |
+
# pick top-|score| gene sets
|
| 353 |
+
top_idx = np.argsort(-np.abs(filtered_scores.X[0]))[:k_value]
|
| 354 |
+
gene_sets = filtered_scores.var.index[top_idx].tolist()
|
| 355 |
+
scores = filtered_scores.X[0, top_idx].tolist()
|
| 356 |
+
|
| 357 |
+
# build the narrative
|
| 358 |
+
header = (
|
| 359 |
+
"VISION scores are single-cell gene-set enrichment values computed by the "
|
| 360 |
+
"VISION algorithm (DeTomaso & Yosef 2021). Positive scores indicate relative "
|
| 361 |
+
"up-regulation of the gene set in the queried condition; negative scores indicate "
|
| 362 |
+
"down-regulation.\n"
|
| 363 |
+
)
|
| 364 |
+
lines = []
|
| 365 |
+
for gs, val in zip(gene_sets, scores):
|
| 366 |
+
gs_name = gs.replace("gs_", "")
|
| 367 |
+
genes = self.get_genes_for_set(gs_name)
|
| 368 |
+
direction = "up-regulated" if val > 0 else "down-regulated" if val < 0 else "not changed"
|
| 369 |
+
lines.append(f"{gs} has gene set {genes} : {direction} (VISION score = {val:.3f})")
|
| 370 |
+
|
| 371 |
+
return header + "\n".join(lines)
|
| 372 |
+
|
| 373 |
+
def obtain_moa(self, drug_name: str):
|
| 374 |
+
row = self.tahoe_drug_meta[self.tahoe_drug_meta["drug"] == drug_name]
|
| 375 |
+
|
| 376 |
+
if row.empty:
|
| 377 |
+
return "MOA annotation not found for this drug."
|
| 378 |
+
|
| 379 |
+
moa_broad = row["moa-broad"].values[0]
|
| 380 |
+
moa_fine = row["moa-fine"].values[0]
|
| 381 |
+
|
| 382 |
+
return (
|
| 383 |
+
f"Broad MOA: {moa_broad}; "
|
| 384 |
+
f"Fine MOA: {moa_fine}. "
|
| 385 |
+
"Fine-grained mechanism of action (MOA) annotation for the drug, "
|
| 386 |
+
"specifying the biological process or molecular target affected. "
|
| 387 |
+
"Derived from MedChemExpress and curated with GPT-based annotations."
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def obtain_gene_targets(self, drug_name: str):
|
| 391 |
+
row = self.tahoe_drug_meta[self.tahoe_drug_meta["drug"] == drug_name]
|
| 392 |
+
if row.empty:
|
| 393 |
+
return "Gene targets not found for this drug."
|
| 394 |
+
|
| 395 |
+
targets = row["targets"].values[0]
|
| 396 |
+
|
| 397 |
+
# Convert a stringified list/dict to a Python object, if necessary.
|
| 398 |
+
if isinstance(targets, str):
|
| 399 |
+
try:
|
| 400 |
+
targets = eval(targets)
|
| 401 |
+
except Exception: # fall back to treating it as a single ID
|
| 402 |
+
targets = [targets]
|
| 403 |
+
|
| 404 |
+
return (
|
| 405 |
+
f"Gene target token IDs: {targets}. "
|
| 406 |
+
"Gene identifiers (integer token IDs) corresponding to each gene with non-zero expression in the cell."
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
def obtain_cell_line_data(self, cell_line_name: str):
|
| 410 |
+
row = self.tahoe_cell_meta[self.tahoe_cell_meta["cell_name"] == cell_line_name]
|
| 411 |
+
|
| 412 |
+
if row.empty:
|
| 413 |
+
return "Cell-line metadata not found for this cell line."
|
| 414 |
+
|
| 415 |
+
organ = row["Organ"].values[0]
|
| 416 |
+
driver_gene_symbol = row["Driver_Gene_Symbol"].values[0]
|
| 417 |
+
driver_varzyg = row["Driver_VarZyg"].values[0]
|
| 418 |
+
driver_vartype = row["Driver_VarType"].values[0]
|
| 419 |
+
driver_proteffect = row["Driver_ProtEffect_or_CdnaEffect"].values[0]
|
| 420 |
+
driver_mech_inferdm = row["Driver_Mech_InferDM"].values[0]
|
| 421 |
+
driver_genetype_dm = row["Driver_GeneType_DM"].values[0]
|
| 422 |
+
|
| 423 |
+
return (
|
| 424 |
+
f"Organ: {organ}; "
|
| 425 |
+
f"Driver_Gene_Symbol: {driver_gene_symbol}; "
|
| 426 |
+
f"Driver_VarZyg: {driver_varzyg}; "
|
| 427 |
+
f"Driver_VarType: {driver_vartype}; "
|
| 428 |
+
f"Driver_ProtEffect_or_CdnaEffect: {driver_proteffect}; "
|
| 429 |
+
f"Driver_Mech_InferDM: {driver_mech_inferdm}; "
|
| 430 |
+
f"Driver_GeneType_DM: {driver_genetype_dm}. "
|
| 431 |
+
"Organ = tissue or organ of origin for the cell line (e.g., Lung), used to interpret lineage-specific responses. "
|
| 432 |
+
"Driver_Gene_Symbol = HGNC-approved symbol of a driver gene with functional alterations in this cell line. "
|
| 433 |
+
"Driver_VarZyg = zygosity of the driver variant (Hom = homozygous, Het = heterozygous). "
|
| 434 |
+
"Driver_VarType = type of genetic alteration (e.g., Missense, Frameshift, Stopgain). "
|
| 435 |
+
"Driver_ProtEffect_or_CdnaEffect = precise protein or cDNA-level annotation of the mutation (e.g., p.G12S). "
|
| 436 |
+
"Driver_Mech_InferDM = inferred functional mechanism (LoF = loss-of-function, GoF = gain-of-function). "
|
| 437 |
+
"Driver_GeneType_DM = classification of the driver gene as an Oncogene or Suppressor."
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
def run_gradio_chat(self, message: str,
|
| 441 |
+
history: list,
|
| 442 |
+
temperature: float,
|
| 443 |
+
max_new_tokens: int,
|
| 444 |
+
max_token: int,
|
| 445 |
+
call_agent: bool,
|
| 446 |
+
conversation: gr.State,
|
| 447 |
+
max_round: int = 20,
|
| 448 |
+
seed: int = None,
|
| 449 |
+
call_agent_level: int = 0,
|
| 450 |
+
sub_agent_task: str = None):
|
| 451 |
+
|
| 452 |
+
print("\033[1;32;40mstart\033[0m")
|
| 453 |
+
print("len(message)", len(message))
|
| 454 |
+
|
| 455 |
+
if len(message) <= 10:
|
| 456 |
+
yield "Hi, I am Agent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
|
| 457 |
+
return "Please provide a valid message."
|
| 458 |
+
|
| 459 |
+
outputs = []
|
| 460 |
+
outputs_str = ''
|
| 461 |
+
last_outputs = []
|
| 462 |
+
|
| 463 |
+
conversation = self.initialize_conversation(
|
| 464 |
+
message,
|
| 465 |
+
conversation=conversation,
|
| 466 |
+
history=history)
|
| 467 |
+
|
| 468 |
+
history = []
|
| 469 |
+
|
| 470 |
+
next_round = True
|
| 471 |
+
function_call_messages = []
|
| 472 |
+
current_round = 0
|
| 473 |
+
enable_summary = False
|
| 474 |
+
last_status = {} # for summary
|
| 475 |
+
token_overflow = False
|
| 476 |
+
# if self.enable_checker:
|
| 477 |
+
# checker = ReasoningTraceChecker(
|
| 478 |
+
# message, conversation, init_index=len(conversation))
|
| 479 |
+
|
| 480 |
+
# try:
|
| 481 |
+
self.conversation.append({"role": "user", "content": message})
|
| 482 |
+
while next_round and current_round < max_round:
|
| 483 |
+
current_round += 1
|
| 484 |
+
|
| 485 |
+
response = self.llm_infer(self.conversation)
|
| 486 |
+
self.conversation.append({"role": "system", "content": response})
|
| 487 |
+
tool_called = False
|
| 488 |
+
print(response)
|
| 489 |
+
# import pdb; pdb.set_trace()
|
| 490 |
+
|
| 491 |
+
if 'Tool-call:' in response:
|
| 492 |
+
match = re.search(r"Tool-call:\s*(.*)", response, re.DOTALL)
|
| 493 |
+
response_text = match.group(1).strip()
|
| 494 |
+
if "None" not in response_text and response_text.replace('-', '').rstrip().replace('FINISHED', '').rstrip():
|
| 495 |
+
history.append(ChatMessage(
|
| 496 |
+
role="assistant", content=f"{response.replace('FINISHED', '').split('</think>')[1]}"))
|
| 497 |
+
yield history
|
| 498 |
+
|
| 499 |
+
tool_called = True
|
| 500 |
+
print(response_text)
|
| 501 |
+
if "FAIL" in response_text:
|
| 502 |
+
self.conversation.append({"role": "system", "content": tool_response})
|
| 503 |
+
history.append(
|
| 504 |
+
ChatMessage(role="assistant", content=f"Response from tool FAILED ")
|
| 505 |
+
)
|
| 506 |
+
next_round = False
|
| 507 |
+
yield history
|
| 508 |
+
else:
|
| 509 |
+
tool_call_text = response_text
|
| 510 |
+
if ';' in tool_call_text:
|
| 511 |
+
tool_calls = [i.replace('\n', '').rstrip('-').replace('FINISHED', '').replace('Response:', '') for i in tool_call_text.split(';') if i]
|
| 512 |
+
elif '\n' in tool_call_text:
|
| 513 |
+
tool_calls = [i.replace('\n', '').rstrip('-').replace('FINISHED', '').replace('Response:', '') for i in tool_call_text.split('\n') if i]
|
| 514 |
+
else:
|
| 515 |
+
tool_calls = [tool_call_text]
|
| 516 |
+
|
| 517 |
+
tool_calls = [i.rstrip('-') for i in tool_calls if i]
|
| 518 |
+
|
| 519 |
+
for call in tool_calls:
|
| 520 |
+
print(f"\033[1;34;40mCalling this command now {call}\033[0m")
|
| 521 |
+
tool_response = str(eval(call))
|
| 522 |
+
self.conversation.append({"role": "system", "content": tool_response})
|
| 523 |
+
history.append(
|
| 524 |
+
ChatMessage(role="assistant", content=f"Response from tool: {tool_response}")
|
| 525 |
+
)
|
| 526 |
+
print(f"\033[1;34;40mGot this response {tool_response}\033[0m")
|
| 527 |
+
yield history
|
| 528 |
+
else:
|
| 529 |
+
history.append(
|
| 530 |
+
ChatMessage(role="assistant", content=f"{response}")
|
| 531 |
+
)
|
| 532 |
+
yield history
|
| 533 |
+
|
| 534 |
+
elif 'Response:' in response or tool_called is False:
|
| 535 |
+
match = re.search(r"Response:\s*(.*)", response, re.DOTALL)
|
| 536 |
+
response_text = match.group(1).strip().replace('Tool-call: None', '')
|
| 537 |
+
print(f"\033[1;33;40mresponse text: {response_text}\033[0m")
|
| 538 |
+
history.append(
|
| 539 |
+
ChatMessage(
|
| 540 |
+
role="assistant", content=f"{response_text.replace('FINISHED', '')}")
|
| 541 |
+
)
|
| 542 |
+
yield history
|
| 543 |
+
|
| 544 |
+
if 'FINISHED' in response and tool_called is False:
|
| 545 |
+
next_round = False
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
# if len(last_outputs) > 0:
|
| 553 |
+
# function_call_messages, picked_tools_prompt, special_tool_call, current_gradio_history = yield from self.run_function_call_stream(
|
| 554 |
+
# last_outputs, return_message=True,
|
| 555 |
+
# existing_tools_prompt=picked_tools_prompt,
|
| 556 |
+
# message_for_call_agent=message,
|
| 557 |
+
# call_agent=call_agent,
|
| 558 |
+
# call_agent_level=call_agent_level,
|
| 559 |
+
# temperature=temperature)
|
| 560 |
+
# history.extend(current_gradio_history)
|
| 561 |
+
# if special_tool_call == 'Finish':
|
| 562 |
+
# yield history
|
| 563 |
+
# next_round = False
|
| 564 |
+
# conversation.extend(function_call_messages)
|
| 565 |
+
# return function_call_messages[0]['content']
|
| 566 |
+
# elif special_tool_call == 'RequireClarification' or special_tool_call == 'DirectResponse':
|
| 567 |
+
# history.append(
|
| 568 |
+
# ChatMessage(role="assistant", content=history[-1].content))
|
| 569 |
+
# yield history
|
| 570 |
+
# next_round = False
|
| 571 |
+
# return history[-1].content
|
| 572 |
+
# if (self.enable_summary or token_overflow) and not call_agent:
|
| 573 |
+
# if token_overflow:
|
| 574 |
+
# print("token_overflow, using summary")
|
| 575 |
+
# enable_summary = True
|
| 576 |
+
# last_status = self.function_result_summary(
|
| 577 |
+
# conversation, status=last_status,
|
| 578 |
+
# enable_summary=enable_summary)
|
| 579 |
+
# if function_call_messages is not None:
|
| 580 |
+
# conversation.extend(function_call_messages)
|
| 581 |
+
# formated_md_function_call_messages = tool_result_format(
|
| 582 |
+
# function_call_messages)
|
| 583 |
+
# yield history
|
| 584 |
+
# else:
|
| 585 |
+
# next_round = False
|
| 586 |
+
# conversation.extend(
|
| 587 |
+
# [{"role": "assistant", "content": ''.join(last_outputs)}])
|
| 588 |
+
# return ''.join(last_outputs).replace("</s>", "")
|
| 589 |
+
# # if self.enable_checker:
|
| 590 |
+
# # good_status, wrong_info = checker.check_conversation()
|
| 591 |
+
# # if not good_status:
|
| 592 |
+
# # next_round = False
|
| 593 |
+
# # print("Internal error in reasoning: " + wrong_info)
|
| 594 |
+
# # break
|
| 595 |
+
# last_outputs = []
|
| 596 |
+
# last_outputs_str, token_overflow = self.llm_infer(
|
| 597 |
+
# messages=conversation,
|
| 598 |
+
# temperature=temperature,
|
| 599 |
+
# tools=picked_tools_prompt,
|
| 600 |
+
# skip_special_tokens=False,
|
| 601 |
+
# max_new_tokens=max_new_tokens,
|
| 602 |
+
# max_token=max_token,
|
| 603 |
+
# seed=seed,
|
| 604 |
+
# check_token_status=True)
|
| 605 |
+
# last_thought = last_outputs_str.split("[TOOL_CALLS]")[0]
|
| 606 |
+
# for each in history:
|
| 607 |
+
# if each.metadata is not None:
|
| 608 |
+
# each.metadata['status'] = 'done'
|
| 609 |
+
# if '[FinalAnswer]' in last_thought:
|
| 610 |
+
# final_thought, final_answer = last_thought.split(
|
| 611 |
+
# '[FinalAnswer]')
|
| 612 |
+
# history.append(
|
| 613 |
+
# ChatMessage(role="assistant",
|
| 614 |
+
# content=final_thought.strip())
|
| 615 |
+
# )
|
| 616 |
+
# yield history
|
| 617 |
+
# history.append(
|
| 618 |
+
# ChatMessage(
|
| 619 |
+
# role="assistant", content="**Answer**:\n"+final_answer.strip())
|
| 620 |
+
# )
|
| 621 |
+
# yield history
|
| 622 |
+
# else:
|
| 623 |
+
# history.append(ChatMessage(
|
| 624 |
+
# role="assistant", content=last_thought))
|
| 625 |
+
# yield history
|
| 626 |
+
|
| 627 |
+
# last_outputs.append(last_outputs_str)
|
| 628 |
+
|
| 629 |
+
# if self.force_finish:
|
| 630 |
+
# last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
|
| 631 |
+
# conversation, temperature, max_new_tokens, max_token, return_full_thought=True)
|
| 632 |
+
# for each in history:
|
| 633 |
+
# if each.metadata is not None:
|
| 634 |
+
# each.metadata['status'] = 'done'
|
| 635 |
+
|
| 636 |
+
# final_thought, final_answer = last_outputs_str.split('[FinalAnswer]')
|
| 637 |
+
# history.append(
|
| 638 |
+
# ChatMessage(role="assistant",
|
| 639 |
+
# content=final_thought.strip())
|
| 640 |
+
# )
|
| 641 |
+
# yield history
|
| 642 |
+
# history.append(
|
| 643 |
+
# ChatMessage(
|
| 644 |
+
# role="assistant", content="**Answer**:\n"+final_answer.strip())
|
| 645 |
+
# )
|
| 646 |
+
# yield history
|
| 647 |
+
# else:
|
| 648 |
+
# yield "The number of rounds exceeds the maximum limit!"
|
| 649 |
+
|
| 650 |
+
# except Exception as e:
|
| 651 |
+
# print(f"Error: {e}")
|
| 652 |
+
# if self.force_finish:
|
| 653 |
+
# last_outputs_str = self.get_answer_based_on_unfinished_reasoning(
|
| 654 |
+
# conversation,
|
| 655 |
+
# temperature,
|
| 656 |
+
# max_new_tokens,
|
| 657 |
+
# max_token,
|
| 658 |
+
# return_full_thought=True)
|
| 659 |
+
# for each in history:
|
| 660 |
+
# if each.metadata is not None:
|
| 661 |
+
# each.metadata['status'] = 'done'
|
| 662 |
+
|
| 663 |
+
# final_thought, final_answer = last_outputs_str.split(
|
| 664 |
+
# '[FinalAnswer]')
|
| 665 |
+
# history.append(
|
| 666 |
+
# ChatMessage(role="assistant",
|
| 667 |
+
# content=final_thought.strip())
|
| 668 |
+
# )
|
| 669 |
+
# yield history
|
| 670 |
+
# history.append(
|
| 671 |
+
# ChatMessage(
|
| 672 |
+
# role="assistant", content="**Answer**:\n"+final_answer.strip())
|
| 673 |
+
# )
|
| 674 |
+
# yield history
|
| 675 |
+
# else:
|
| 676 |
+
# return None
|
agent/prompt.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Agent_Prompt = """
|
| 2 |
+
|
| 3 |
+
You are an assistant who is helping the user identify novel gene sets for a particular disease. All your responses must be in the following format.
|
| 4 |
+
If you don't use a tool then don't include a tool-call, if you don't need to respond to the user and instead want to solely call a tool then don't include a response.
|
| 5 |
+
Return FINISHED at the end of a response if you have responded to the user query. DO NOT hallucinate, guess, or ASSUME tool responses. If you need to call multiple tools separate the tool-calls with a semi-colon; :
|
| 6 |
+
|
| 7 |
+
Reasoning:
|
| 8 |
+
|
| 9 |
+
[Your reasoning goes here]
|
| 10 |
+
|
| 11 |
+
Response:
|
| 12 |
+
|
| 13 |
+
[Your response goes here, if necessary]
|
| 14 |
+
|
| 15 |
+
Tool-call:
|
| 16 |
+
|
| 17 |
+
[Tool call goes here, if necessary]
|
| 18 |
+
|
| 19 |
+
------------------------------------------------
|
| 20 |
+
|
| 21 |
+
The tools you have in your disposal are:
|
| 22 |
+
|
| 23 |
+
(1) A tool which can tell you the k-most diseases that are similar to your query disease.
|
| 24 |
+
|
| 25 |
+
The tool call for this agent is: "self.get_similar_disease(disease_name, k_value)" where disease_name must be a string and k_value must be an integer. The output of this tool is a list of disease names.
|
| 26 |
+
|
| 27 |
+
(2) A tool which can retrieve the gene targets validated from JUMP-CP dataset.
|
| 28 |
+
|
| 29 |
+
The tool call for this agent is: "self.get_validated_target_jump(drug_name)" where drug_name must be a string. The output of this tool is a list of gene targets.
|
| 30 |
+
|
| 31 |
+
(3) A tool which can retrieve an IC50 value for a drug and cell line from the PRISM Repurposing 20Q2 dataset.
|
| 32 |
+
|
| 33 |
+
The tool call for this agent is: "self.get_ic50_prism(drug_name, cell_line)" where drug_name and cell_line must be strings. The output of this tool is scalar IC50 floating point value. These are not keyword arguments.
|
| 34 |
+
|
| 35 |
+
(4) A tool which can retrieve gene-set expression scores from the Tahoe-100M dataset.
|
| 36 |
+
|
| 37 |
+
The tool call for this agent is "self.rank_vision_scores(drug_name, cell_line, k_value)" where drug_name and cell_line must be strings and k_value must be an integer. These are not keyword arguments. The output of this tool is a list of tuples, where each tuple contains a gene-set name and its corresponding expression score.
|
| 38 |
+
|
| 39 |
+
(5) A tool which can obtain the mechanism of action for a drug from the Tahoe-100M dataset.
|
| 40 |
+
|
| 41 |
+
The tool call for this agent is "self.obtain_moa(drug_name)" where drug_name must be a string. This is not a keyword argument. The output of this tool is dictionary that contains a broad mechanism of action and a more specific mechanism of action.
|
| 42 |
+
|
| 43 |
+
(6) A tool which can retrieve the gene targets for a drug from the Tahoe-100M dataset.
|
| 44 |
+
|
| 45 |
+
The tool call for this agent is: "self.obtain_gene_targets(drug_name)" where drug_name must be a string. This is not a keyword argument. The output of this tool is a list of gene symbols representing the known molecular targets of the compound.
|
| 46 |
+
|
| 47 |
+
(7) A tool which can retrieve the cell line metadata from the Tahoe-100M dataset.
|
| 48 |
+
|
| 49 |
+
The tool call for this agent is: "self.obtain_cell_line_data(cell_line_name)" where cell_line_name must be a string. This is not a keyword argument. The output of this tool is a dictionary containing information about key driver mutations for each cell line.
|
| 50 |
+
|
| 51 |
+
(8) A tool which can retrieve the LC50 value for a drug and cell line from the NCI-60 dataset.
|
| 52 |
+
|
| 53 |
+
The tool call for this agent is: "self.get_lc50_nci60(drug_name, cell_line_name)" where drug_name and cell_line_name must be strings. These are not keyword arguments.The output of this tool is a tuple of (LC50, LCONC). The LC50 value is in log10 scale and the LCONC is a scalar value that is in log10 scale. It is thelog of the highest concentration tested.
|
| 54 |
+
|
| 55 |
+
(9) A tool which searches for the similar perturbation effect within the Tahoe dataset.
|
| 56 |
+
|
| 57 |
+
The tool call for this agent is: "self.get_similar_drug_effect_in_tahoe(cell_line_name, drug_name)" where cell_line_name and drug_name must be strings. These are not keyword arguments. The output of this tool is a string of dataframe that tells you about what other **drugs** in tahoe have similar perturbation effect on the cell line.
|
| 58 |
+
|
| 59 |
+
(10) A tool that does natural perturbation search in Cellxgene database for similar drug effect.
|
| 60 |
+
|
| 61 |
+
The tool call for this agent is: "self.get_similar_drug_effects_in_cxg(cell_line_name, drug_name)" where cell_line_name and drug_name must be strings. These are not keyword arguments. The output of this tool is a string of dataframe that tells you about what **disease** and cell types in cellxgene have similar perturbation effect on the cell line.
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
"""
|
agent/utils.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import AzureOpenAI, OpenAI
|
| 2 |
+
import yaml
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Basic_Agent():
|
| 6 |
+
|
| 7 |
+
def __init__(self, config):
|
| 8 |
+
self.config = self.load_config(config)
|
| 9 |
+
self.openai_api_key = self.config['openai_api_key']
|
| 10 |
+
if 'open_api_base' in self.config:
|
| 11 |
+
self.open_api_base = self.config['open_api_base']
|
| 12 |
+
self.azure_openai_api_key = self.config['azure_openai_api_key']
|
| 13 |
+
self.azure_openai_endpoint = self.config['azure_openai_endpoint']
|
| 14 |
+
self.openai_backend = self.config['openai_backend']
|
| 15 |
+
# self.pqapi_token = self.config['pqapi_token']
|
| 16 |
+
# os.environ['PQA_API_TOKEN'] = self.pqapi_token
|
| 17 |
+
|
| 18 |
+
def load_config(self,config_file):
|
| 19 |
+
with open(config_file, 'r') as file:
|
| 20 |
+
return yaml.safe_load(file)
|
| 21 |
+
|
| 22 |
+
def llm_infer(self, conversation, temp = 0.000000001, max_tokens = 1000, image = None, role = None):
|
| 23 |
+
|
| 24 |
+
while True:
|
| 25 |
+
|
| 26 |
+
if self.openai_backend == 'azure':
|
| 27 |
+
client = AzureOpenAI(
|
| 28 |
+
azure_endpoint = self.azure_openai_endpoint,
|
| 29 |
+
api_key=self.azure_openai_api_key,
|
| 30 |
+
api_version="2024-05-01-preview")
|
| 31 |
+
|
| 32 |
+
response = client.chat.completions.create(
|
| 33 |
+
model='gpt-4o',
|
| 34 |
+
messages = conversation,
|
| 35 |
+
temperature=temp,
|
| 36 |
+
max_tokens=max_tokens,
|
| 37 |
+
)
|
| 38 |
+
elif self.openai_backend == 'openai':
|
| 39 |
+
client = OpenAI(
|
| 40 |
+
api_key=self.openai_api_key
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
response = client.chat.completions.create(
|
| 44 |
+
model='gpt-4o',
|
| 45 |
+
messages=conversation,
|
| 46 |
+
temperature=temp,
|
| 47 |
+
max_tokens=max_tokens,
|
| 48 |
+
)
|
| 49 |
+
elif self.openai_backend == 'lambda':
|
| 50 |
+
|
| 51 |
+
client = OpenAI(api_key = self.openai_api_key,
|
| 52 |
+
base_url = self.open_api_base)
|
| 53 |
+
|
| 54 |
+
model = "deepseek-r1-671b"
|
| 55 |
+
response = client.chat.completions.create(
|
| 56 |
+
model = model,
|
| 57 |
+
messages = conversation)
|
| 58 |
+
else:
|
| 59 |
+
raise ValueError(f"Invalid openai_backend: {self.openai_backend}")
|
| 60 |
+
|
| 61 |
+
if "I'm sorry, I can't assist with that" in response.choices[0].message.content or "I'm unable to view the image" in response.choices[0].message.content or "I'm unable to provide a definitive answer" in response.choices[0].message.content:
|
| 62 |
+
print("Failed to generate response, trying again")
|
| 63 |
+
continue
|
| 64 |
+
else:
|
| 65 |
+
response = response.choices[0].message.content
|
| 66 |
+
return response
|
| 67 |
+
|
| 68 |
+
def run_function(self, output):
|
| 69 |
+
try:
|
| 70 |
+
tool_call = output.split('Tool-call:')[-1].rstrip().replace('\n', '')
|
| 71 |
+
res = eval(tool_call)
|
| 72 |
+
return res
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f"Error in parsing tool call in {output} got this error {e}")
|
| 75 |
+
import pdb; pdb.set_trace()
|
data/README.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Data Download
|
| 2 |
+
|
| 3 |
+
Download jump signatures
|
| 4 |
+
|
| 5 |
+
```
|
| 6 |
+
wget https://cellpainting-gallery.s3.amazonaws.com/cpg0016-jump-assembled/source_all/workspace/profiles/jump-profiling-recipe_2024_a917fa7/COMPOUND/profiles_var_mad_int_featselect_harmony/profiles_var_mad_int.parquet
|
| 7 |
+
```
|
data/jump-dataset.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
data/jump-similarity.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
model.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import anndata
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.utils.data import DataLoader, TensorDataset
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
# Generate random sample data - 2000 samples, 1280 dimensions each
|
| 9 |
+
num_samples = 2000
|
| 10 |
+
dimension = 1280
|
| 11 |
+
|
| 12 |
+
# Generate random input vectors X
|
| 13 |
+
X = np.random.randn(num_samples, dimension)
|
| 14 |
+
|
| 15 |
+
# Generate target vectors Y (could be random or a function of X)
|
| 16 |
+
# Option 1: Completely random Y
|
| 17 |
+
# Y = np.random.randn(num_samples, dimension)
|
| 18 |
+
|
| 19 |
+
# Option 2: Y as a noisy function of X (more realistic for regression task)
|
| 20 |
+
W = np.random.randn(dimension, dimension) * 0.1 # Random weight matrix
|
| 21 |
+
b = np.random.randn(dimension) * 0.1 # Random bias
|
| 22 |
+
noise = np.random.randn(num_samples, dimension) * 0.05 # Random noise
|
| 23 |
+
Y = X @ W + b + noise # Y = XW + b + noise
|
| 24 |
+
|
| 25 |
+
# Convert data to PyTorch tensors
|
| 26 |
+
X_tensor = torch.tensor(X, dtype=torch.float32)
|
| 27 |
+
Y_tensor = torch.tensor(Y, dtype=torch.float32)
|
| 28 |
+
dataset = TensorDataset(X_tensor, Y_tensor)
|
| 29 |
+
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
|
| 30 |
+
|
| 31 |
+
# Option 1: Simple Linear Regression model
|
| 32 |
+
class LinearModel(nn.Module):
|
| 33 |
+
def __init__(self):
|
| 34 |
+
super(LinearModel, self).__init__()
|
| 35 |
+
self.linear = nn.Linear(1280, 1280)
|
| 36 |
+
|
| 37 |
+
def forward(self, x):
|
| 38 |
+
return self.linear(x)
|
| 39 |
+
|
| 40 |
+
# Option 2: Neural Network with hidden layers
|
| 41 |
+
class NeuralNetwork(nn.Module):
|
| 42 |
+
def __init__(self, hidden_dim=512):
|
| 43 |
+
super(NeuralNetwork, self).__init__()
|
| 44 |
+
self.network = nn.Sequential(
|
| 45 |
+
nn.Linear(1280, hidden_dim),
|
| 46 |
+
nn.ReLU(),
|
| 47 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 48 |
+
nn.ReLU(),
|
| 49 |
+
nn.Linear(hidden_dim, 1280)
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
def forward(self, x):
|
| 53 |
+
return self.network(x)
|
| 54 |
+
|
| 55 |
+
# Choose which model to use
|
| 56 |
+
# model = LinearModel()
|
| 57 |
+
model = NeuralNetwork()
|
| 58 |
+
|
| 59 |
+
# Loss function and optimizer
|
| 60 |
+
criterion = nn.MSELoss()
|
| 61 |
+
optimizer = optim.Adam(model.parameters(), lr=0.001)
|
| 62 |
+
|
| 63 |
+
# Training loop
|
| 64 |
+
num_epochs = 50
|
| 65 |
+
for epoch in range(num_epochs):
|
| 66 |
+
total_loss = 0
|
| 67 |
+
for inputs, targets in dataloader:
|
| 68 |
+
# Forward pass
|
| 69 |
+
outputs = model(inputs)
|
| 70 |
+
loss = criterion(outputs, targets)
|
| 71 |
+
|
| 72 |
+
# Backward pass and optimize
|
| 73 |
+
optimizer.zero_grad()
|
| 74 |
+
loss.backward()
|
| 75 |
+
optimizer.step()
|
| 76 |
+
|
| 77 |
+
total_loss += loss.item()
|
| 78 |
+
|
| 79 |
+
# Print progress
|
| 80 |
+
if (epoch + 1) % 5 == 0:
|
| 81 |
+
print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}')
|
| 82 |
+
|
| 83 |
+
# After training, to use the model for prediction:
|
| 84 |
+
def predict(input_vector):
|
| 85 |
+
model.eval()
|
| 86 |
+
with torch.no_grad():
|
| 87 |
+
input_tensor = torch.tensor(input_vector, dtype=torch.float32)
|
| 88 |
+
return model(input_tensor).numpy()
|
run_app.py
ADDED
|
@@ -0,0 +1,296 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import datetime
|
| 3 |
+
import sys
|
| 4 |
+
from agent.agent import SigSpace
|
| 5 |
+
import spaces
|
| 6 |
+
import gradio as gr
|
| 7 |
+
import os
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
os.environ["VLLM_USE_V1"] = "0" # Disable v1 API for now since it does not support logits processors.
|
| 13 |
+
|
| 14 |
+
# Determine the directory where the current file is located
|
| 15 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 16 |
+
os.environ["MKL_THREADING_LAYER"] = "GNU"
|
| 17 |
+
|
| 18 |
+
# Set an environment variable
|
| 19 |
+
HF_TOKEN = os.environ.get("HF_TOKEN", None)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# Create the image path - use absolute path for reliability
|
| 23 |
+
img_path = os.path.join(current_dir, 'img', 'SigSpace.png')
|
| 24 |
+
|
| 25 |
+
def display_image(image_path):
|
| 26 |
+
# Load and return the image
|
| 27 |
+
img = Image.open(image_path)
|
| 28 |
+
return img
|
| 29 |
+
|
| 30 |
+
DESCRIPTION = f'''
|
| 31 |
+
<div style="text-align: center;">
|
| 32 |
+
<h1 style="font-size: 32px; margin-bottom: 10px;">SigSpace: An AI Agent for Tahoe-100M</h1>
|
| 33 |
+
</div>
|
| 34 |
+
'''
|
| 35 |
+
INTRO = """
|
| 36 |
+
This is the intro that goes here
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
LICENSE = """
|
| 40 |
+
License goes here
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
PLACEHOLDER = """
|
| 44 |
+
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
|
| 45 |
+
<h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Agent</h1>
|
| 46 |
+
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using Agent:</p>
|
| 47 |
+
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
|
| 48 |
+
(top-right) to remove previous context before sumbmitting a new question.</p>
|
| 49 |
+
<p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
|
| 50 |
+
</div>
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
css = """
|
| 54 |
+
h1 {
|
| 55 |
+
text-align: center;
|
| 56 |
+
display: block;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
#duplicate-button {
|
| 60 |
+
margin: auto;
|
| 61 |
+
color: white;
|
| 62 |
+
background: #1565c0;
|
| 63 |
+
border-radius: 100vh;
|
| 64 |
+
}
|
| 65 |
+
.small-button button {
|
| 66 |
+
font-size: 12px !important;
|
| 67 |
+
padding: 4px 8px !important;
|
| 68 |
+
height: 6px !important;
|
| 69 |
+
width: 4px !important;
|
| 70 |
+
}
|
| 71 |
+
.gradio-accordion {
|
| 72 |
+
margin-top: 0px !important;
|
| 73 |
+
margin-bottom: 0px !important;
|
| 74 |
+
}
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
chat_css = """
|
| 78 |
+
.gr-button { font-size: 20px !important; } /* Enlarges button icons */
|
| 79 |
+
.gr-button svg { width: 32px !important; height: 32px !important; } /* Enlarges SVG icons */
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
model_name = ''
|
| 83 |
+
|
| 84 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
question_examples = [
|
| 88 |
+
# ['What is the IC50 values for the drug Abemaciclib in the cell line A549?'],
|
| 89 |
+
["What's the MoA of the drug Ponatinib on the HCT15 colon cancer cell line? Please synthesize results from the Tahoe-100M dataset, the jump dataset, and the IC50 dataset."],
|
| 90 |
+
["Natural perturbation: find the disease perturbation that has the similar effect to Glycyrrhizic acid on CVCL_0334? use the result and what you know to explain the mechanism of action."],
|
| 91 |
+
["Mechanism of action: give me the mechanism of action for drug name Abemaciclib provided by Tahoe."],
|
| 92 |
+
["Vision scores: what are the top 5 vision scores for cell line A549 and drug name Abemaciclib"]
|
| 93 |
+
]
|
| 94 |
+
|
| 95 |
+
new_tool_files = {
|
| 96 |
+
'new_tool': os.path.join(current_dir, 'data', 'new_tool.json'),
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
config_path = "/home/ubuntu/.lambda_api_config.yaml"
|
| 100 |
+
agent = SigSpace(config_path)
|
| 101 |
+
# agent.init_model()
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def update_model_parameters(enable_finish, enable_rag, enable_summary,
|
| 105 |
+
init_rag_num, step_rag_num, skip_last_k,
|
| 106 |
+
summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
|
| 107 |
+
# Update model instance parameters dynamically
|
| 108 |
+
updated_params = agent.update_parameters(
|
| 109 |
+
enable_finish=enable_finish,
|
| 110 |
+
enable_rag=enable_rag,
|
| 111 |
+
enable_summary=enable_summary,
|
| 112 |
+
init_rag_num=init_rag_num,
|
| 113 |
+
step_rag_num=step_rag_num,
|
| 114 |
+
skip_last_k=skip_last_k,
|
| 115 |
+
summary_mode=summary_mode,
|
| 116 |
+
summary_skip_last_k=summary_skip_last_k,
|
| 117 |
+
summary_context_length=summary_context_length,
|
| 118 |
+
force_finish=force_finish,
|
| 119 |
+
seed=seed,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
return updated_params
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def update_seed():
|
| 126 |
+
# Update model instance parameters dynamically
|
| 127 |
+
seed = random.randint(0, 10000)
|
| 128 |
+
updated_params = agent.update_parameters(
|
| 129 |
+
seed=seed,
|
| 130 |
+
)
|
| 131 |
+
return updated_params
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
|
| 135 |
+
print("Updated seed:", update_seed())
|
| 136 |
+
new_history = history[:retry_data.index]
|
| 137 |
+
previous_prompt = history[retry_data.index]['content']
|
| 138 |
+
|
| 139 |
+
print("previous_prompt", previous_prompt)
|
| 140 |
+
|
| 141 |
+
yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
PASSWORD = "mypassword"
|
| 145 |
+
|
| 146 |
+
# Function to check if the password is correct
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def check_password(input_password):
|
| 150 |
+
if input_password == PASSWORD:
|
| 151 |
+
return gr.update(visible=True), ""
|
| 152 |
+
else:
|
| 153 |
+
return gr.update(visible=False), "Incorrect password, try again!"
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
conversation_state = gr.State([])
|
| 157 |
+
|
| 158 |
+
# Gradio block
|
| 159 |
+
chatbot = gr.Chatbot(height=400, placeholder=PLACEHOLDER,
|
| 160 |
+
label='SigSpace', type="messages", show_copy_button=True)
|
| 161 |
+
|
| 162 |
+
with gr.Blocks(css=css) as demo:
|
| 163 |
+
gr.Markdown(DESCRIPTION)
|
| 164 |
+
# gr.Markdown(INTRO)
|
| 165 |
+
gr.Image(value=display_image(img_path), label="", show_label=False, height=600, width=600)
|
| 166 |
+
default_temperature = 0.3
|
| 167 |
+
default_max_new_tokens = 1024
|
| 168 |
+
default_max_tokens = 81920
|
| 169 |
+
default_max_round = 30
|
| 170 |
+
temperature_state = gr.State(value=default_temperature)
|
| 171 |
+
max_new_tokens_state = gr.State(value=default_max_new_tokens)
|
| 172 |
+
max_tokens_state = gr.State(value=default_max_tokens)
|
| 173 |
+
max_round_state = gr.State(value=default_max_round)
|
| 174 |
+
chatbot.retry(handle_retry, chatbot, chatbot, temperature_state, max_new_tokens_state,
|
| 175 |
+
max_tokens_state, gr.Checkbox(value=False, render=False), conversation_state, max_round_state)
|
| 176 |
+
|
| 177 |
+
gr.ChatInterface(
|
| 178 |
+
fn=agent.run_gradio_chat,
|
| 179 |
+
chatbot=chatbot,
|
| 180 |
+
fill_height=False, fill_width=False, stop_btn=True,
|
| 181 |
+
additional_inputs_accordion=gr.Accordion(
|
| 182 |
+
label="⚙️ Inference Parameters", open=False, render=False),
|
| 183 |
+
additional_inputs=[
|
| 184 |
+
temperature_state, max_new_tokens_state, max_tokens_state,
|
| 185 |
+
gr.Checkbox(
|
| 186 |
+
label="Activate X", value=False, render=False),
|
| 187 |
+
conversation_state,
|
| 188 |
+
max_round_state,
|
| 189 |
+
gr.Number(label="Seed", value=100, render=False)
|
| 190 |
+
],
|
| 191 |
+
examples=question_examples,
|
| 192 |
+
cache_examples=False,
|
| 193 |
+
css=chat_css,
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
with gr.Accordion("Settings", open=False):
|
| 197 |
+
|
| 198 |
+
# Define the sliders
|
| 199 |
+
temperature_slider = gr.Slider(
|
| 200 |
+
minimum=0,
|
| 201 |
+
maximum=1,
|
| 202 |
+
step=0.1,
|
| 203 |
+
value=default_temperature,
|
| 204 |
+
label="Temperature"
|
| 205 |
+
)
|
| 206 |
+
max_new_tokens_slider = gr.Slider(
|
| 207 |
+
minimum=128,
|
| 208 |
+
maximum=4096,
|
| 209 |
+
step=1,
|
| 210 |
+
value=default_max_new_tokens,
|
| 211 |
+
label="Max new tokens"
|
| 212 |
+
)
|
| 213 |
+
max_tokens_slider = gr.Slider(
|
| 214 |
+
minimum=128,
|
| 215 |
+
maximum=32000,
|
| 216 |
+
step=1,
|
| 217 |
+
value=default_max_tokens,
|
| 218 |
+
label="Max tokens"
|
| 219 |
+
)
|
| 220 |
+
max_round_slider = gr.Slider(
|
| 221 |
+
minimum=0,
|
| 222 |
+
maximum=50,
|
| 223 |
+
step=1,
|
| 224 |
+
value=default_max_round,
|
| 225 |
+
label="Max round")
|
| 226 |
+
|
| 227 |
+
# Automatically update states when slider values change
|
| 228 |
+
temperature_slider.change(
|
| 229 |
+
lambda x: x, inputs=temperature_slider, outputs=temperature_state)
|
| 230 |
+
max_new_tokens_slider.change(
|
| 231 |
+
lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
|
| 232 |
+
max_tokens_slider.change(
|
| 233 |
+
lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
|
| 234 |
+
max_round_slider.change(
|
| 235 |
+
lambda x: x, inputs=max_round_slider, outputs=max_round_state)
|
| 236 |
+
|
| 237 |
+
# password_input = gr.Textbox(
|
| 238 |
+
# label="Enter Password for More Settings", type="password")
|
| 239 |
+
# incorrect_message = gr.Textbox(visible=False, interactive=False)
|
| 240 |
+
# with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
|
| 241 |
+
# with gr.Row():
|
| 242 |
+
# with gr.Column(scale=1):
|
| 243 |
+
# with gr.Accordion("⚙️ Model Loading", open=False):
|
| 244 |
+
# model_name_input = gr.Textbox(
|
| 245 |
+
# label="Enter model path", value=model_name)
|
| 246 |
+
# load_model_btn = gr.Button(value="Load Model")
|
| 247 |
+
# load_model_btn.click(
|
| 248 |
+
# agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
|
| 249 |
+
# with gr.Column(scale=1):
|
| 250 |
+
# with gr.Accordion("⚙️ Functional Parameters", open=False):
|
| 251 |
+
# # Create Gradio components for parameter inputs
|
| 252 |
+
# enable_finish = gr.Checkbox(
|
| 253 |
+
# label="Enable Finish", value=True)
|
| 254 |
+
# enable_rag = gr.Checkbox(
|
| 255 |
+
# label="Enable RAG", value=True)
|
| 256 |
+
# enable_summary = gr.Checkbox(
|
| 257 |
+
# label="Enable Summary", value=False)
|
| 258 |
+
# init_rag_num = gr.Number(
|
| 259 |
+
# label="Initial RAG Num", value=0)
|
| 260 |
+
# step_rag_num = gr.Number(
|
| 261 |
+
# label="Step RAG Num", value=10)
|
| 262 |
+
# skip_last_k = gr.Number(label="Skip Last K", value=0)
|
| 263 |
+
# summary_mode = gr.Textbox(
|
| 264 |
+
# label="Summary Mode", value='step')
|
| 265 |
+
# summary_skip_last_k = gr.Number(
|
| 266 |
+
# label="Summary Skip Last K", value=0)
|
| 267 |
+
# summary_context_length = gr.Number(
|
| 268 |
+
# label="Summary Context Length", value=None)
|
| 269 |
+
# force_finish = gr.Checkbox(
|
| 270 |
+
# label="Force FinalAnswer", value=True)
|
| 271 |
+
# seed = gr.Number(label="Seed", value=100)
|
| 272 |
+
# # Button to submit and update parameters
|
| 273 |
+
# submit_btn = gr.Button("Update Parameters")
|
| 274 |
+
|
| 275 |
+
# # Display the updated parameters
|
| 276 |
+
# updated_parameters_output = gr.JSON()
|
| 277 |
+
|
| 278 |
+
# # When button is clicked, update parameters
|
| 279 |
+
# submit_btn.click(fn=update_model_parameters,
|
| 280 |
+
# inputs=[enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
|
| 281 |
+
# summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed],
|
| 282 |
+
# outputs=updated_parameters_output)
|
| 283 |
+
# Button to submit the password
|
| 284 |
+
# submit_button = gr.Button("Submit")
|
| 285 |
+
|
| 286 |
+
# # When the button is clicked, check if the password is correct
|
| 287 |
+
# submit_button.click(
|
| 288 |
+
# check_password,
|
| 289 |
+
# inputs=password_input,
|
| 290 |
+
# outputs=[protected_accordion, incorrect_message]
|
| 291 |
+
# )
|
| 292 |
+
gr.Markdown(LICENSE)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
if __name__ == "__main__":
|
| 296 |
+
demo.launch(share=True)
|
tahoe_model/apply_linear_model.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import anndata
|
| 2 |
+
import joblib
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
from sklearn.model_selection import train_test_split
|
| 6 |
+
from sklearn.linear_model import LinearRegression
|
| 7 |
+
from sklearn.metrics import mean_squared_error
|
| 8 |
+
from scipy.stats import pearsonr
|
| 9 |
+
import os
|
| 10 |
+
import pandas as pd
|
| 11 |
+
|
| 12 |
+
merged_anndata = anndata.read_h5ad("data/tahoe_vision_universal_embeddings.h5ad")
|
| 13 |
+
|
| 14 |
+
X = merged_anndata.obsm["X_delta"] # 60125 x 1280
|
| 15 |
+
Y = merged_anndata.X # 60125 x 7467
|
| 16 |
+
labels = merged_anndata.var.index.tolist() # 7467
|
| 17 |
+
|
| 18 |
+
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size = 0.2, random_state = 42)
|
| 19 |
+
|
| 20 |
+
if os.path.exists("models/linear_regression_model.pkl"):
|
| 21 |
+
model = joblib.load("models/linear_regression_model.pkl")
|
| 22 |
+
|
| 23 |
+
y_pred_test = model.predict(X_test)
|
| 24 |
+
test_pearson = [pearsonr(y_test[:, i], y_pred_test[:, i])[0] for i in range(y_test.shape[1])]
|
| 25 |
+
|
| 26 |
+
top_gene_set_indices = np.argsort(test_pearson)[-20:][::-1]
|
| 27 |
+
|
| 28 |
+
top_gene_sets = [(test_pearson[i], labels[i]) for i in top_gene_set_indices]
|
| 29 |
+
|
| 30 |
+
print("Top 20 gene sets with the highest correlation:")
|
| 31 |
+
for correlation, gene_set in top_gene_sets:
|
| 32 |
+
print(f"gene set {gene_set}: pearson correlation = {correlation:.4f}")
|
| 33 |
+
|
| 34 |
+
plt.hist(test_pearson, bins=50, color='blue', alpha=0.7)
|
| 35 |
+
plt.title("Distribution of Pearson Correlation Coefficients (Test Set)")
|
| 36 |
+
plt.xlabel("Pearson Correlation Coefficient")
|
| 37 |
+
plt.ylabel("Frequency")
|
| 38 |
+
plt.grid(axis='y', alpha=0.75)
|
| 39 |
+
|
| 40 |
+
if not os.path.exists("figures"):
|
| 41 |
+
os.makedirs("figures")
|
| 42 |
+
|
| 43 |
+
plt.savefig("figures/pearson_correlation_distribution.png")
|
| 44 |
+
|
| 45 |
+
top_20_indices_per_row = np.argsort(np.abs(y_test), axis=1)[:, -20:]
|
| 46 |
+
|
| 47 |
+
correlations = []
|
| 48 |
+
for i in range(y_test.shape[0]):
|
| 49 |
+
actual_top_20 = y_test[i, top_20_indices_per_row[i]]
|
| 50 |
+
predicted_top_20 = y_pred_test[i, top_20_indices_per_row[i]]
|
| 51 |
+
correlation = pearsonr(actual_top_20, predicted_top_20)[0]
|
| 52 |
+
correlations.append(correlation)
|
| 53 |
+
|
| 54 |
+
average_correlation = np.mean(correlations)
|
| 55 |
+
print(f"Average correlation for top 20 magnitude gene sets per row: {average_correlation:.4f}")
|
| 56 |
+
|
| 57 |
+
else:
|
| 58 |
+
model = LinearRegression()
|
| 59 |
+
|
| 60 |
+
model.fit(X_train, y_train)
|
| 61 |
+
|
| 62 |
+
y_pred_train = model.predict(X_train)
|
| 63 |
+
y_pred_test = model.predict(X_test)
|
| 64 |
+
|
| 65 |
+
train_mse = mean_squared_error(y_train, y_pred_train)
|
| 66 |
+
test_mse = mean_squared_error(y_test, y_pred_test)
|
| 67 |
+
|
| 68 |
+
print(f"training MSE: {train_mse}")
|
| 69 |
+
print(f"testing MSE: {test_mse}")
|
| 70 |
+
|
| 71 |
+
joblib.dump(model, "models/linear_regression_model.pkl")
|
| 72 |
+
|
| 73 |
+
model = joblib.load("models/linear_regression_model.pkl")
|
| 74 |
+
|
| 75 |
+
disease_deltas = anndata.read_h5ad("data/disease_deltas.h5ad")
|
| 76 |
+
predicted_vision_signatures = model.predict(disease_deltas.X)
|
| 77 |
+
|
| 78 |
+
dataframe = pd.DataFrame(predicted_vision_signatures, columns = labels)
|
| 79 |
+
|
| 80 |
+
labels_combined = disease_deltas.obs.apply(
|
| 81 |
+
lambda row: f"{row['cell_type']}_{row['tissue']}_{row['disease']}", axis=1
|
| 82 |
+
).tolist()
|
| 83 |
+
|
| 84 |
+
top_20_gene_sets = []
|
| 85 |
+
|
| 86 |
+
for index, row in dataframe.iterrows():
|
| 87 |
+
top_20_indices = np.argsort(np.abs(row))[-20:][::-1]
|
| 88 |
+
top_20 = [(labels[i], "down" if row.iloc[i] < 0 else "up") for i in top_20_indices]
|
| 89 |
+
top_20_gene_sets.append(top_20)
|
| 90 |
+
|
| 91 |
+
with open("top_20_gene_sets.txt", "w") as f:
|
| 92 |
+
for i, gene_set in enumerate(top_20_gene_sets):
|
| 93 |
+
f.write(f"{labels_combined[i]}\t" + "\t".join([f"{gene}:{direction}" for gene, direction in gene_set]) + "\n")
|
tahoe_model/compute_tahoe_deltas.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import anndata
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
uce = anndata.read_h5ad("data/tahoe_universal_embeddings.h5ad")
|
| 5 |
+
|
| 6 |
+
control_condition = "[('DMSO_TF', 0.0, 'uM')]"
|
| 7 |
+
|
| 8 |
+
X_delta = np.zeros_like(uce.obsm["X_uce"])
|
| 9 |
+
|
| 10 |
+
for cell_line in uce.obs["cell_line"].unique():
|
| 11 |
+
for plate in uce.obs["plate"].unique():
|
| 12 |
+
cell_plate_mask = (uce.obs["cell_line"] == cell_line) & (uce.obs["plate"] == plate)
|
| 13 |
+
control_mask = cell_plate_mask & (uce.obs["drugname_drugconc"] == control_condition)
|
| 14 |
+
|
| 15 |
+
cell_plate_indices = np.where(cell_plate_mask)[0]
|
| 16 |
+
control_indices = np.where(control_mask)[0]
|
| 17 |
+
|
| 18 |
+
X_delta[cell_plate_indices] = uce.obsm["X_uce"][cell_plate_indices] - uce.obsm["X_uce"][control_indices]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
uce.obsm["X_delta"] = X_delta
|
| 22 |
+
|
| 23 |
+
print("X_uce shape", uce.obsm["X_uce"].shape)
|
| 24 |
+
print("X_delta shape", uce.obsm["X_delta"].shape)
|
| 25 |
+
|
| 26 |
+
uce.write("data/tahoe_universal_embeddings_deltas.h5ad")
|
tahoe_model/merge_tahoe_vision.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import anndata
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
uce = anndata.read_h5ad("data/tahoe_universal_embeddings_deltas.h5ad")
|
| 5 |
+
vision = anndata.read_h5ad("data/tahoe_vision_scores.h5ad")
|
| 6 |
+
|
| 7 |
+
uce.obs = uce.obs.reset_index().rename(columns = {"index": "condition"})
|
| 8 |
+
vision.obs["condition"] = vision.obs.apply(
|
| 9 |
+
lambda row: f"{row['Cell_ID_Cellosaur']}_{[(row['drug'], row['concentration'], row['concentration_unit'])]}_plate{row['plate']}", axis = 1
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
unique_cell_lines_uce = uce.obs["cell_line"].unique().tolist()
|
| 13 |
+
unique_drugs_uce = uce.obs["drugname_drugconc"].apply(
|
| 14 |
+
lambda x: eval(x)[0][0]
|
| 15 |
+
).unique().tolist()
|
| 16 |
+
|
| 17 |
+
# print("number of unique cell lines:", len(unique_cell_lines_uce))
|
| 18 |
+
# print(unique_cell_lines_uce)
|
| 19 |
+
|
| 20 |
+
# print("\nnumber of unique drugs:", len(unique_drugs_uce))
|
| 21 |
+
# print(unique_drugs_uce)
|
| 22 |
+
|
| 23 |
+
conditions_uce = set(uce.obs["condition"].unique())
|
| 24 |
+
conditions_vision = set(vision.obs["condition"].unique())
|
| 25 |
+
|
| 26 |
+
only_in_uce = conditions_uce - conditions_vision
|
| 27 |
+
only_in_vision = conditions_vision - conditions_uce
|
| 28 |
+
|
| 29 |
+
with open("conditions_only_in_uce.txt", "w") as f:
|
| 30 |
+
for condition in only_in_uce:
|
| 31 |
+
f.write(f"{condition}\n")
|
| 32 |
+
|
| 33 |
+
with open("conditions_only_in_vision.txt", "w") as f:
|
| 34 |
+
for condition in only_in_vision:
|
| 35 |
+
f.write(f"{condition}\n")
|
| 36 |
+
|
| 37 |
+
vision = vision[vision.obs["condition"].drop_duplicates(keep = "first").index, :]
|
| 38 |
+
vision.obs = vision.obs.reset_index(drop = True)
|
| 39 |
+
|
| 40 |
+
merged_obs = pd.merge(
|
| 41 |
+
uce.obs,
|
| 42 |
+
vision.obs,
|
| 43 |
+
on = "condition",
|
| 44 |
+
how = "inner"
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
indices_in_vision = vision.obs.index[
|
| 48 |
+
vision.obs["condition"].isin(merged_obs["condition"])
|
| 49 |
+
].tolist()
|
| 50 |
+
|
| 51 |
+
indices_in_uce = uce.obs.index[
|
| 52 |
+
uce.obs["condition"].isin(merged_obs["condition"])
|
| 53 |
+
].tolist()
|
| 54 |
+
|
| 55 |
+
indices_in_vision = [int(x) for x in indices_in_vision]
|
| 56 |
+
indices_in_uce = [int(x) for x in indices_in_uce]
|
| 57 |
+
|
| 58 |
+
anndata_merged = anndata.AnnData(
|
| 59 |
+
X = vision.X[indices_in_vision, :],
|
| 60 |
+
obs = merged_obs,
|
| 61 |
+
var = vision.var,
|
| 62 |
+
obsm = {"X_uce" : uce.obsm["X_uce"][indices_in_uce, :],
|
| 63 |
+
"X_delta": uce.obsm["X_delta"][indices_in_uce, :]}
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
anndata_merged.write("data/tahoe_vision_universal_embeddings.h5ad")
|