Commit 9735e01 by Sid01123 · 1 parent: 54aa369

Commit message: all the code
README.md ADDED
@@ -0,0 +1,53 @@
---
license: mit
datasets:
- tahoebio/Tahoe-100M
tags:
- tahoe-deepdive
- hackathon
- tahoe-100M
---

<div align="center">
  <img src="img/SigSpace.png" alt="SigSpace Logo" width="400"/>
</div>

# SigSpace: An AI Agent for the Tahoe-100M dataset
This is a submission for the **Tahoe-DeepDive Hackathon 2025**.

# Team Name
SigSpace

## Members
- Ishita Mangla
- Kuan Pang
- Giovanni Palla
- Yanay Rosen
- Sid Sanghi
- Yasha Ektefaie
- Rohit Khurana

# Project
## SigSpace: An AI Agent for the Tahoe-100M dataset

## Overview
We developed an AI agent that accesses the Tahoe-100M dataset alongside publicly available and novel datasets. The agent refines and expands the mechanisms of action (MOA) and drug signatures of the perturbations in the Tahoe-100M dataset.

## Motivation
Drug discovery in the age of Large Language Models (LLMs) can be enhanced through agentic workflows that parse diverse sources of unstructured information to synthesize and connect hypotheses across fields and modalities. However, these models are primarily trained on text and lack the capacity to interrogate rich biological databases with complex, biologically motivated queries. In this work, we provide a proof of concept demonstrating how the Tahoe-100M dataset can be integrated with relevant public datasets to expand the hypothesis space for mechanisms of action and drug responses of the perturbations tested in Tahoe-100M.

## Methods
We curated new datasets that enrich the description of the drugs and cell lines present in the Tahoe-100M dataset.

Specifically:
- TAHOE-100M: VISION gene-set scores and metadata.
- PRISM: drug sensitivity data reporting the concentration of a compound needed to inhibit 50% of cancer cell viability. Measurements are based on pooled screening of barcoded cell lines and provide a high-throughput assessment of drug response across a large panel of cancer models.
- NCI-60: LC50 data reporting the concentration of a drug that kills 50% of the cells present at the time of drug addition, measured across a panel of 60 human cancer cell lines using standardized multi-dose assays.
- JUMP: morphological profiles of cells in response to chemical and genetic perturbations. High-content imaging and automated feature extraction quantify cellular changes, enabling large-scale profiling of perturbation effects across diverse biological contexts.
- UCE-CXG-EMBEDDING: natural-perturbation search using an AI virtual cell.

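The PRISM lookup above can be sketched in a few lines. The miniature IC50 matrix, DepMap IDs, and values below are invented for illustration; the repository's real matrix is read from `Tahoe_PRISM_cell_by_drug_ic50_matrix_named.csv` and has the same layout (cell-line rows, lower-cased drug-name columns).

```python
import pandas as pd

# Hypothetical miniature of the PRISM IC50 matrix used in this repo:
# rows are DepMap cell-line IDs, columns are lower-cased drug names,
# values are log10-transformed micromolar concentrations.
ic50 = pd.DataFrame(
    {"tamoxifen": [-0.52, 0.31], "daptomycin": [1.10, None]},
    index=["ACH-000001", "ACH-000002"],
)

def lookup_ic50(matrix: pd.DataFrame, depmap_id: str, drug: str):
    """Return the IC50 entry, or None if the pair is absent or unmeasured."""
    drug = drug.strip().lower()  # matrix columns are lower-cased
    if depmap_id not in matrix.index or drug not in matrix.columns:
        return None
    val = matrix.loc[depmap_id, drug]
    return None if pd.isna(val) else float(val)

print(lookup_ic50(ic50, "ACH-000001", "Tamoxifen"))   # -0.52
print(lookup_ic50(ic50, "ACH-000002", "daptomycin"))  # None (missing measurement)
```

Lower values indicate greater sensitivity of the cell line to the drug; missing pairs simply were not measured in the screen.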
## Results

We developed a Gradio application that accesses these databases and performs complex queries, enhancing and grounding the agent's reasoning in real biological measurements.

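The application drives an LLM loop that emits turns in the structured Reasoning / Response / Tool-call format defined in `agent/prompt.py`. A minimal sketch of how one such turn can be parsed; the sample model output and the drug/cell-line names in it are invented for illustration.

```python
import re

# A made-up model turn in the Reasoning / Response / Tool-call format.
turn = """Reasoning:
The user asked about venetoclax in HepG2, so I should check PRISM first.

Response:
Let me look up the measured sensitivity.

Tool-call:
self.get_ic50_prism("venetoclax", "HepG2"); self.obtain_moa("venetoclax")
FINISHED"""

def parse_tool_calls(text: str) -> list[str]:
    """Extract the Tool-call section and split multiple calls on ';'."""
    match = re.search(r"Tool-call:\s*(.*)", text, re.DOTALL)
    if match is None:
        return []
    body = match.group(1).replace("FINISHED", "").strip()
    if not body or body == "None":
        return []
    return [call.strip() for call in body.split(";") if call.strip()]

print(parse_tool_calls(turn))
```

Each extracted call is then dispatched to the corresponding tool method on the agent, and the tool's output is appended back into the conversation for the next round.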
## Discussion
agent/__pycache__/agent.cpython-313.pyc ADDED
Binary file (23.8 kB).

agent/__pycache__/prompt.cpython-313.pyc ADDED
Binary file (3.7 kB).

agent/__pycache__/utils.cpython-313.pyc ADDED
Binary file (4.22 kB).
agent/agent.py ADDED
@@ -0,0 +1,676 @@
from agent.utils import *
from agent.prompt import *
import anndata
import gradio as gr
from gradio import ChatMessage
import re
import pandas as pd
import pathlib
import numpy as np

class SigSpace(Basic_Agent):
    def __init__(self, config_path: str):
        super().__init__(config_path)
        self.system_prompt = Agent_Prompt
        self.conversation = [{"role": "system", "content": self.system_prompt}]

        # Initialize JUMP data
        path = pathlib.Path("/home/ubuntu/giovanni/code/Tahoe_Hackathon/datasets")
        self.jump_tahoe_drug_metadata = pd.read_csv(path / "drug_metadata_inchikey.csv")
        self.jump_similarity_score = pd.read_csv(path / "compound_genetic_perturbation_cosine_similarity_inchikey.csv")

        # Load PRISM IC50 matrix; lower-case drug columns for case-insensitive lookup
        self.ic50 = pd.read_csv(path / "Tahoe_PRISM_cell_by_drug_ic50_matrix_named.csv", index_col=0)
        self.ic50.columns = self.ic50.columns.str.lower()

        # Load NCI-60 LC50 data and drop rows with a missing CELL value
        self.lc50 = pd.read_csv(path / "filtered_results.csv")
        self.lc50 = self.lc50[self.lc50["CELL"].notna()]

        # Load full Tahoe metadata
        self.tahoe_cell_meta = pd.read_csv(path / "cell_line_metadata.csv")
        self.tahoe_drug_meta = pd.read_csv(path / "drug_metadata.csv")
        self.tahoe_vision_scores = anndata.read_h5ad(path / "tahoe_vision_scores.h5ad")

        # Load the PRISM-matched subset of Tahoe metadata
        self.prism_tahoe_cell_meta = pd.read_csv(path / "Tahoe_PRISM_matched_cell_metadata_final.csv")
        self.prism_tahoe_drug_meta = pd.read_csv(path / "Tahoe_PRISM_matched_drug_metadata_final.csv")

        # Map cell line common names (whitespace-stripped) to DepMap IDs
        self.cell_name_to_depmap = {
            row["cell_name"].strip(): row["Cell_ID_DepMap"]
            for _, row in self.prism_tahoe_cell_meta.iterrows()
        }

        self.cell_name_to_depmap_lc50 = {
            row["clean"].strip(): row["cell_line_name"]
            for _, row in self.lc50.iterrows()
        }

        self.tahoe_similarity_score = pd.read_csv(path / "in_tahoe_search_result_df.csv")
        self.tahoe_cxg_similarity_score = pd.read_csv(path / "cxg_search_result_df.csv")

    def initialize_conversation(self, message, conversation=None, history=None):
        if conversation is None:
            conversation = []

        conversation.append({"role": "system", "content": Agent_Prompt})

        if history is not None:
            if len(history) == 0:
                conversation = []
                print("Cleared conversation history")
            else:
                # Replay prior user/assistant turns in order
                for i in range(len(history)):
                    if history[i]["role"] == "user":
                        if i - 1 >= 0 and history[i - 1]["role"] == "assistant":
                            conversation.append(
                                {"role": "assistant", "content": history[i - 1]["content"]})
                        conversation.append(
                            {"role": "user", "content": history[i]["content"]})
                    if i == len(history) - 1 and history[i]["role"] == "assistant":
                        conversation.append(
                            {"role": "assistant", "content": history[i]["content"]})

        conversation.append({"role": "user", "content": message})

        return conversation

    def get_similar_disease(self, disease_name, k_value):
        # Placeholder stub: only the demo query is handled
        if disease_name != "Alzheimer's":
            return "FAIL"
        return "Parkinson's Disease"

    def get_validated_target_jump(self, drug_name):
        print(drug_name)
        try:
            inchikey = self.jump_tahoe_drug_metadata[self.jump_tahoe_drug_metadata.drug.isin([drug_name])]["InChIKey"].values[0]
            similarity_scores = self.jump_similarity_score[self.jump_similarity_score.InChIKey.isin([inchikey])]

            # Count ORF entries with cosine similarity > 0.2 and < -0.2
            orf_positive = similarity_scores[(similarity_scores.Genetic_Perturbation == 'ORF') & (similarity_scores.cosine_sim > 0.2)].shape[0]
            orf_negative = similarity_scores[(similarity_scores.Genetic_Perturbation == 'ORF') & (similarity_scores.cosine_sim < -0.2)].shape[0]

            # Count CRISPR entries with cosine similarity > 0.2 and < -0.2
            crispr_positive = similarity_scores[(similarity_scores.Genetic_Perturbation == 'CRISPR') & (similarity_scores.cosine_sim > 0.2)].shape[0]
            crispr_negative = similarity_scores[(similarity_scores.Genetic_Perturbation == 'CRISPR') & (similarity_scores.cosine_sim < -0.2)].shape[0]

            orf_targets = f"ORF: {orf_positive} positive correlations (>0.2), {orf_negative} negative correlations (<-0.2)"
            crispr_targets = f"CRISPR: {crispr_positive} positive correlations (>0.2), {crispr_negative} negative correlations (<-0.2)"

            orf_crispr_counts = orf_targets + " " + crispr_targets

            known_targets_from_jump = self.jump_tahoe_drug_metadata[self.jump_tahoe_drug_metadata.drug.isin([drug_name])]["target_list"].values[0]
            known_targets_output = f"The known targets from the JUMP dataset are: {', '.join(known_targets_from_jump.split('|'))}"
        except Exception as e:
            print(e)
            return f"For the drug {drug_name}, we were not able to find the target in the JUMP dataset."

        return f"""
        Perturbation description:

        ORF: The ORF perturbation consists of an overexpression of the target gene.
        CRISPR: The CRISPR perturbation consists of a knockout of the target gene.

        Considering the drug "{drug_name}", we expect positive correlations with shared CRISPR targets,
        and negative correlations with shared ORF targets.

        But the measured correlations are:

        {orf_crispr_counts}

        Furthermore, the JUMP dataset has the following known targets for the drug "{drug_name}":

        {known_targets_output}
        """

    def get_similar_drug_effect_in_tahoe(self, cell_line_name: str, drug_name: str):
        """
        Get drugs with similar effects in Tahoe, given a drug name and cell line name.

        Args:
            cell_line_name (str): The name of the cell line.
            drug_name (str): The name of the drug.
        """
        cell_line_names = self.tahoe_similarity_score["source_cell_line"].unique().tolist()
        drug_names = self.tahoe_similarity_score["source_drug_name"].unique().tolist()
        if cell_line_name not in cell_line_names:
            return "FAIL: Cell line name not found in the dataset. A valid example: CVCL_0218"
        if drug_name not in drug_names:
            return "FAIL: Drug name not found in the dataset. A valid example: Daptomycin"
        hits = self.tahoe_similarity_score[
            (self.tahoe_similarity_score["source_cell_line"] == cell_line_name) &
            (self.tahoe_similarity_score["source_drug_name"] == drug_name)
        ]
        # Sort by distance (ascending) and keep the 10 nearest hits
        hits = hits.sort_values(by="distance", ascending=True).reset_index(drop=True)
        hits = hits.head(10)
        # Keep target_drug_name and target_cell_line
        hits = hits[["target_drug_name", "target_cell_line"]]
        outputs = f"""
        The following drugs have similar effects to the drug you provided:
        hits:
        {hits}
        """
        return outputs

    def get_similar_drug_effects_in_cxg(self, cell_line_name: str, drug_name: str):
        """
        Get diseases with similar effects in CELLxGENE (CxG), given a drug name and cell line name.

        Args:
            cell_line_name (str): The name of the cell line.
            drug_name (str): The name of the drug.
        """
        cell_line_names = self.tahoe_cxg_similarity_score["cell_line"].unique().tolist()
        drug_names = self.tahoe_cxg_similarity_score["perturbation_drug_name"].unique().tolist()
        if cell_line_name not in cell_line_names:
            return "FAIL: Cell line name not found in the dataset. A valid example: CVCL_0218"
        if drug_name not in drug_names:
            return "FAIL: Drug name not found in the dataset. A valid example: Daptomycin"
        hits = self.tahoe_cxg_similarity_score[
            (self.tahoe_cxg_similarity_score["cell_line"] == cell_line_name) &
            (self.tahoe_cxg_similarity_score["perturbation_drug_name"] == drug_name)
        ]
        hits = hits.sort_values(by="distance", ascending=True).reset_index(drop=True)
        hits = hits.head(10)
        # Keep cell_type, tissue_type, and disease
        hits = hits[["cell_type", "tissue_type", "disease"]]
        outputs = f"""
        The following diseases have similar effects to the drug you provided:
        hits:
        {hits}
        """
        return outputs

    def get_ic50_prism(self, drug_name: str, cell_line_name: str):
        drug_name_lower = drug_name.strip().lower()
        cell_line_key = cell_line_name.strip()

        if cell_line_key not in self.cell_name_to_depmap:
            print(f"Cell line name '{cell_line_key}' not found for PRISM data")
            return f"FAIL: Cell line name '{cell_line_key}' not found for PRISM data"

        depmap_id = self.cell_name_to_depmap[cell_line_key]

        if drug_name_lower not in self.ic50.columns:
            print(f"Drug name '{drug_name}' not found in IC50 matrix columns.")
            return f"FAIL: Drug name '{drug_name}' not found in IC50 matrix columns."

        try:
            ic50_val = self.ic50.loc[depmap_id, drug_name_lower]
            if pd.isna(ic50_val):
                print(f"FAIL: IC50 value is missing for '{drug_name}' in cell line '{cell_line_name}' (DepMap ID: {depmap_id}).")
                return f"FAIL: IC50 value is missing for '{drug_name}' in cell line '{cell_line_name}' (DepMap ID: {depmap_id})."

            return (
                f"The IC50 value of {ic50_val:.4f} corresponds to the log10-transformed micromolar concentration "
                f"at which {drug_name} inhibits 50% of viability in the {cell_line_name} cell line "
                f"(DepMap ID: {depmap_id}).\n\n"
                "This value comes from the PRISM Repurposing Secondary Screen, which exposes pooled barcoded cell lines "
                "to drug treatment for 5 days and infers viability from barcode abundance using sequencing.\n\n"
                "The secondary screen includes higher-confidence compound–cell line pairs with improved replicability "
                "compared to the primary screen.\n\n"
                "Lower IC50 values indicate greater sensitivity of the cell line to the drug."
            )
        except KeyError as e:
            print(f"Combination not found: {e}")
            return None

    def clean_cell_line_name(self, name):
        """
        Standardize cell line names for comparison by:
        1. Converting to string (handles any non-string values)
        2. Converting to uppercase
        3. Removing all non-alphanumeric characters

        Args:
            name: Cell line name (string or other type)

        Returns:
            Cleaned string with only uppercase letters and numbers
        """
        return re.sub(r"[^A-Z0-9]", "", str(name).upper())

    def get_lc50_nci60(self, drug_name: str, cell_line_name: str):
        cell_line_name = cell_line_name.upper()
        cell_line_key = self.clean_cell_line_name(cell_line_name)

        if cell_line_key not in self.cell_name_to_depmap_lc50:
            print(f"Cell line name '{cell_line_key}' not found for NCI60 data")
            return None
        depmap_id = self.cell_name_to_depmap_lc50[cell_line_key]
        print("Depmap_id", depmap_id)

        # Drug names are stored in uppercase, so normalize the search term
        drug_name_upper = drug_name.strip().upper()

        # Filter rows whose drug column contains the search term
        matching_row = self.lc50[self.lc50['drug'].str.contains(drug_name_upper, na=False)]
        if matching_row.empty:
            print(f"Drug name '{drug_name}' not found in NCI60 dataset.")
            return None

        if len(matching_row) > 1:
            print(f"Multiple matches found for drug '{drug_name}' in NCI60 dataset; using the first.")

        print("Matching row", matching_row)
        # Get the LC50 and LCONC values from the first matching row
        lc50_val = matching_row.iloc[0]['NLOGLC50']
        lconc_val = matching_row.iloc[0]['LCONC']

        if pd.isna(lc50_val):
            return f"LC50 value is missing for '{drug_name}' in cell line '{cell_line_name}' (depmap_id: {depmap_id})."

        lc50_output = f"""
        The LC50 value of {lc50_val} represents -log10(LC50), the negative base-10 logarithm of the molar concentration that inhibits 50% of cell growth.

        Higher values therefore indicate greater drug potency.

        The LCONC value of {lconc_val} denotes the maximum log10 molar concentration tested in the dilution series; for example, LCONC = -4 corresponds to 10^-4 M.

        Both metrics come from the NCI-60 drug screen, which applies a standardized 48-hour exposure assay across all compound–cell-line pairs.
        """

        return lc50_output

    def load_gene_sets_file(self, file_path):
        """
        Load gene sets from a tab-delimited file where the first column is the gene set name
        and the remaining columns are gene symbols.

        Parameters:
        -----------
        file_path : str
            Path to the gene sets file

        Returns:
        --------
        dict
            Dictionary mapping gene set names to lists of genes
        """
        gene_sets = {}
        with open(file_path, 'r') as file:
            for line in file:
                parts = line.strip().split('\t')
                if parts:
                    set_name = parts[0]
                    genes = [gene for gene in parts[1:] if gene]  # Filter out empty strings
                    gene_sets[set_name] = genes
        return gene_sets

    def get_genes_for_set(self, set_name):
        """
        Get the list of genes for a specific gene set.

        Parameters:
        -----------
        set_name : str
            Name of the gene set to query

        Returns:
        --------
        list
            List of genes in the gene set, or empty list if set not found
        """
        if not hasattr(self, 'gene_sets'):
            # Load the gene sets file lazily on first use
            self.gene_sets = self.load_gene_sets_file('/home/ubuntu/ishita/msigdb_all_sigs_human_symbols.txt')

        return self.gene_sets.get(set_name, [])

    def rank_vision_scores(self, drug_name: str, cell_line_name: str, k_value: int):
        # Z-score the VISION matrix once (guard against re-normalizing on repeated calls)
        if not getattr(self, "_vision_scores_normalized", False):
            self.tahoe_vision_scores.X = (
                self.tahoe_vision_scores.X - np.mean(self.tahoe_vision_scores.X, axis=0)
            ) / np.std(self.tahoe_vision_scores.X, axis=0)
            self._vision_scores_normalized = True

        # Subset to the drug / cell line at the highest tested concentration
        filt = (
            (self.tahoe_vision_scores.obs["Cell_Name_Vevo"] == cell_line_name)
            & (self.tahoe_vision_scores.obs["drug"] == drug_name)
        )
        filtered_scores = self.tahoe_vision_scores[filt]
        if filtered_scores.n_obs == 0:
            return "VISION scores not found for this drug–cell-line combination."

        filtered_scores = filtered_scores[
            filtered_scores.obs["concentration"] == filtered_scores.obs["concentration"].max()
        ]

        # Pick the top-|score| gene sets
        top_idx = np.argsort(-np.abs(filtered_scores.X[0]))[:k_value]
        gene_sets = filtered_scores.var.index[top_idx].tolist()
        scores = filtered_scores.X[0, top_idx].tolist()

        # Build the narrative
        header = (
            "VISION scores are single-cell gene-set enrichment values computed by the "
            "VISION algorithm (DeTomaso & Yosef 2021). Positive scores indicate relative "
            "up-regulation of the gene set in the queried condition; negative scores indicate "
            "down-regulation.\n"
        )
        lines = []
        for gs, val in zip(gene_sets, scores):
            gs_name = gs.replace("gs_", "")
            genes = self.get_genes_for_set(gs_name)
            direction = "up-regulated" if val > 0 else "down-regulated" if val < 0 else "not changed"
            lines.append(f"{gs} has gene set {genes} : {direction} (VISION score = {val:.3f})")

        return header + "\n".join(lines)

    def obtain_moa(self, drug_name: str):
        row = self.tahoe_drug_meta[self.tahoe_drug_meta["drug"] == drug_name]

        if row.empty:
            return "MOA annotation not found for this drug."

        moa_broad = row["moa-broad"].values[0]
        moa_fine = row["moa-fine"].values[0]

        return (
            f"Broad MOA: {moa_broad}; "
            f"Fine MOA: {moa_fine}. "
            "Fine-grained mechanism of action (MOA) annotation for the drug, "
            "specifying the biological process or molecular target affected. "
            "Derived from MedChemExpress and curated with GPT-based annotations."
        )

    def obtain_gene_targets(self, drug_name: str):
        import ast

        row = self.tahoe_drug_meta[self.tahoe_drug_meta["drug"] == drug_name]
        if row.empty:
            return "Gene targets not found for this drug."

        targets = row["targets"].values[0]

        # Convert a stringified list/dict to a Python object, if necessary.
        if isinstance(targets, str):
            try:
                targets = ast.literal_eval(targets)
            except (ValueError, SyntaxError):  # fall back to treating it as a single ID
                targets = [targets]

        return (
            f"Gene target token IDs: {targets}. "
            "Gene identifiers (integer token IDs) corresponding to each gene with non-zero expression in the cell."
        )

    def obtain_cell_line_data(self, cell_line_name: str):
        row = self.tahoe_cell_meta[self.tahoe_cell_meta["cell_name"] == cell_line_name]

        if row.empty:
            return "Cell-line metadata not found for this cell line."

        organ = row["Organ"].values[0]
        driver_gene_symbol = row["Driver_Gene_Symbol"].values[0]
        driver_varzyg = row["Driver_VarZyg"].values[0]
        driver_vartype = row["Driver_VarType"].values[0]
        driver_proteffect = row["Driver_ProtEffect_or_CdnaEffect"].values[0]
        driver_mech_inferdm = row["Driver_Mech_InferDM"].values[0]
        driver_genetype_dm = row["Driver_GeneType_DM"].values[0]

        return (
            f"Organ: {organ}; "
            f"Driver_Gene_Symbol: {driver_gene_symbol}; "
            f"Driver_VarZyg: {driver_varzyg}; "
            f"Driver_VarType: {driver_vartype}; "
            f"Driver_ProtEffect_or_CdnaEffect: {driver_proteffect}; "
            f"Driver_Mech_InferDM: {driver_mech_inferdm}; "
            f"Driver_GeneType_DM: {driver_genetype_dm}. "
            "Organ = tissue or organ of origin for the cell line (e.g., Lung), used to interpret lineage-specific responses. "
            "Driver_Gene_Symbol = HGNC-approved symbol of a driver gene with functional alterations in this cell line. "
            "Driver_VarZyg = zygosity of the driver variant (Hom = homozygous, Het = heterozygous). "
            "Driver_VarType = type of genetic alteration (e.g., Missense, Frameshift, Stopgain). "
            "Driver_ProtEffect_or_CdnaEffect = precise protein or cDNA-level annotation of the mutation (e.g., p.G12S). "
            "Driver_Mech_InferDM = inferred functional mechanism (LoF = loss-of-function, GoF = gain-of-function). "
            "Driver_GeneType_DM = classification of the driver gene as an Oncogene or Suppressor."
        )

    def run_gradio_chat(self, message: str,
                        history: list,
                        temperature: float,
                        max_new_tokens: int,
                        max_token: int,
                        call_agent: bool,
                        conversation: gr.State,
                        max_round: int = 20,
                        seed: int = None,
                        call_agent_level: int = 0,
                        sub_agent_task: str = None):

        print("\033[1;32;40mstart\033[0m")
        print("len(message)", len(message))

        if len(message) <= 10:
            yield "Hi, I am Agent, an assistant for answering biomedical questions. Please provide a valid message with a string longer than 10 characters."
            return

        conversation = self.initialize_conversation(
            message,
            conversation=conversation,
            history=history)

        history = []

        next_round = True
        current_round = 0

        self.conversation.append({"role": "user", "content": message})
        while next_round and current_round < max_round:
            current_round += 1

            response = self.llm_infer(self.conversation)
            self.conversation.append({"role": "system", "content": response})
            tool_called = False
            print(response)

            if 'Tool-call:' in response:
                match = re.search(r"Tool-call:\s*(.*)", response, re.DOTALL)
                response_text = match.group(1).strip()
                if "None" not in response_text and response_text.replace('-', '').rstrip().replace('FINISHED', '').rstrip():
                    history.append(ChatMessage(
                        role="assistant", content=f"{response.replace('FINISHED', '').split('</think>')[-1]}"))
                    yield history

                    tool_called = True
                    print(response_text)
                    if "FAIL" in response_text:
                        self.conversation.append({"role": "system", "content": response_text})
                        history.append(
                            ChatMessage(role="assistant", content="Response from tool FAILED")
                        )
                        next_round = False
                        yield history
                    else:
                        tool_call_text = response_text
                        # Split multiple tool calls on semicolons or newlines
                        if ';' in tool_call_text:
                            tool_calls = [i.replace('\n', '').rstrip('-').replace('FINISHED', '').replace('Response:', '') for i in tool_call_text.split(';') if i]
                        elif '\n' in tool_call_text:
                            tool_calls = [i.replace('\n', '').rstrip('-').replace('FINISHED', '').replace('Response:', '') for i in tool_call_text.split('\n') if i]
                        else:
                            tool_calls = [tool_call_text]

                        tool_calls = [i.rstrip('-') for i in tool_calls if i]

                        for call in tool_calls:
                            print(f"\033[1;34;40mCalling this command now {call}\033[0m")
                            # NOTE: eval on model output is only acceptable in a trusted demo
                            tool_response = str(eval(call))
                            self.conversation.append({"role": "system", "content": tool_response})
                            history.append(
                                ChatMessage(role="assistant", content=f"Response from tool: {tool_response}")
                            )
                            print(f"\033[1;34;40mGot this response {tool_response}\033[0m")
                            yield history
                else:
                    history.append(
                        ChatMessage(role="assistant", content=f"{response}")
                    )
                    yield history

            elif 'Response:' in response or tool_called is False:
                match = re.search(r"Response:\s*(.*)", response, re.DOTALL)
                response_text = match.group(1).strip().replace('Tool-call: None', '')
                print(f"\033[1;33;40mresponse text: {response_text}\033[0m")
                history.append(
                    ChatMessage(
                        role="assistant", content=f"{response_text.replace('FINISHED', '')}")
                )
                yield history

            if 'FINISHED' in response and tool_called is False:
                next_round = False
agent/prompt.py ADDED
@@ -0,0 +1,64 @@
+ Agent_Prompt = """
+
+ You are an assistant who is helping the user identify novel gene sets for a particular disease. All your responses must be in the following format.
+ If you do not use a tool, do not include a tool-call; if you do not need to respond to the user and instead want to solely call a tool, do not include a response.
+ Return FINISHED at the end of a response once you have answered the user query. DO NOT hallucinate, guess, or ASSUME tool responses. If you need to call multiple tools, separate the tool-calls with a semicolon (;).
+
+ Reasoning:
+
+ [Your reasoning goes here]
+
+ Response:
+
+ [Your response goes here, if necessary]
+
+ Tool-call:
+
+ [Tool call goes here, if necessary]
+
+ ------------------------------------------------
+
+ The tools you have at your disposal are:
+
+ (1) A tool which returns the k diseases most similar to your query disease.
+
+ The tool call for this tool is: "self.get_similar_disease(disease_name, k_value)" where disease_name must be a string and k_value must be an integer. The output of this tool is a list of disease names.
+
+ (2) A tool which retrieves the gene targets validated in the JUMP-CP dataset.
+
+ The tool call for this tool is: "self.get_validated_target_jump(drug_name)" where drug_name must be a string. The output of this tool is a list of gene targets.
+
+ (3) A tool which retrieves an IC50 value for a drug and cell line from the PRISM Repurposing 20Q2 dataset.
+
+ The tool call for this tool is: "self.get_ic50_prism(drug_name, cell_line)" where drug_name and cell_line must be strings. These are not keyword arguments. The output of this tool is a scalar floating-point IC50 value.
+
+ (4) A tool which retrieves gene-set expression scores from the Tahoe-100M dataset.
+
+ The tool call for this tool is "self.rank_vision_scores(drug_name, cell_line, k_value)" where drug_name and cell_line must be strings and k_value must be an integer. These are not keyword arguments. The output of this tool is a list of tuples, where each tuple contains a gene-set name and its corresponding expression score.
+
+ (5) A tool which obtains the mechanism of action for a drug from the Tahoe-100M dataset.
+
+ The tool call for this tool is "self.obtain_moa(drug_name)" where drug_name must be a string. This is not a keyword argument. The output of this tool is a dictionary that contains a broad mechanism of action and a more specific mechanism of action.
+
+ (6) A tool which retrieves the gene targets for a drug from the Tahoe-100M dataset.
+
+ The tool call for this tool is: "self.obtain_gene_targets(drug_name)" where drug_name must be a string. This is not a keyword argument. The output of this tool is a list of gene symbols representing the known molecular targets of the compound.
+
+ (7) A tool which retrieves the cell line metadata from the Tahoe-100M dataset.
+
+ The tool call for this tool is: "self.obtain_cell_line_data(cell_line_name)" where cell_line_name must be a string. This is not a keyword argument. The output of this tool is a dictionary containing information about key driver mutations for each cell line.
+
+ (8) A tool which retrieves the LC50 value for a drug and cell line from the NCI-60 dataset.
+
+ The tool call for this tool is: "self.get_lc50_nci60(drug_name, cell_line_name)" where drug_name and cell_line_name must be strings. These are not keyword arguments. The output of this tool is a tuple of (LC50, LCONC). Both values are on a log10 scale; LCONC is the log of the highest concentration tested.
+
+ (9) A tool which searches for similar perturbation effects within the Tahoe dataset.
+
+ The tool call for this tool is: "self.get_similar_drug_effect_in_tahoe(cell_line_name, drug_name)" where cell_line_name and drug_name must be strings. These are not keyword arguments. The output of this tool is a string representation of a dataframe that tells you which other **drugs** in Tahoe have a similar perturbation effect on the cell line.
+
+ (10) A tool that performs a natural-perturbation search in the CELLxGENE database for similar drug effects.
+
+ The tool call for this tool is: "self.get_similar_drug_effects_in_cxg(cell_line_name, drug_name)" where cell_line_name and drug_name must be strings. These are not keyword arguments. The output of this tool is a string representation of a dataframe that tells you which **diseases** and cell types in CELLxGENE have a similar perturbation effect on the cell line.
+
+ """
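For reference, a minimal sketch of how a reply in this format can be parsed downstream (the reply text is made up for illustration; the parsing simply mirrors the `Tool-call:` section and semicolon convention mandated by the prompt):

```python
# Hypothetical agent reply following the format mandated by Agent_Prompt.
reply = """Reasoning:

The user wants targets and MoA for Abemaciclib, so two tools are needed.

Tool-call:

self.obtain_gene_targets("Abemaciclib"); self.obtain_moa("Abemaciclib")"""

def extract_tool_calls(text):
    """Return the list of tool calls, or [] if no Tool-call section exists."""
    if "Tool-call:" not in text:
        return []
    section = text.split("Tool-call:")[-1]
    # Multiple calls are separated by semicolons, per the prompt contract.
    return [call.strip() for call in section.split(";") if call.strip()]

calls = extract_tool_calls(reply)
print(calls)
# → ['self.obtain_gene_targets("Abemaciclib")', 'self.obtain_moa("Abemaciclib")']
```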
agent/utils.py ADDED
@@ -0,0 +1,75 @@
+ from openai import AzureOpenAI, OpenAI
+ import yaml
+
+
+ class Basic_Agent():
+
+     def __init__(self, config):
+         self.config = self.load_config(config)
+         self.openai_api_key = self.config['openai_api_key']
+         if 'open_api_base' in self.config:
+             self.open_api_base = self.config['open_api_base']
+         self.azure_openai_api_key = self.config['azure_openai_api_key']
+         self.azure_openai_endpoint = self.config['azure_openai_endpoint']
+         self.openai_backend = self.config['openai_backend']
+         # self.pqapi_token = self.config['pqapi_token']
+         # os.environ['PQA_API_TOKEN'] = self.pqapi_token
+
+     def load_config(self, config_file):
+         with open(config_file, 'r') as file:
+             return yaml.safe_load(file)
+
+     def llm_infer(self, conversation, temp=1e-9, max_tokens=1000, image=None, role=None):
+         # Retry until the model returns a usable answer instead of a refusal.
+         while True:
+
+             if self.openai_backend == 'azure':
+                 client = AzureOpenAI(
+                     azure_endpoint=self.azure_openai_endpoint,
+                     api_key=self.azure_openai_api_key,
+                     api_version="2024-05-01-preview")
+
+                 response = client.chat.completions.create(
+                     model='gpt-4o',
+                     messages=conversation,
+                     temperature=temp,
+                     max_tokens=max_tokens,
+                 )
+             elif self.openai_backend == 'openai':
+                 client = OpenAI(
+                     api_key=self.openai_api_key
+                 )
+
+                 response = client.chat.completions.create(
+                     model='gpt-4o',
+                     messages=conversation,
+                     temperature=temp,
+                     max_tokens=max_tokens,
+                 )
+             elif self.openai_backend == 'lambda':
+
+                 client = OpenAI(api_key=self.openai_api_key,
+                                 base_url=self.open_api_base)
+
+                 model = "deepseek-r1-671b"
+                 response = client.chat.completions.create(
+                     model=model,
+                     messages=conversation)
+             else:
+                 raise ValueError(f"Invalid openai_backend: {self.openai_backend}")
+
+             if "I'm sorry, I can't assist with that" in response.choices[0].message.content or "I'm unable to view the image" in response.choices[0].message.content or "I'm unable to provide a definitive answer" in response.choices[0].message.content:
+                 print("Failed to generate response, trying again")
+                 continue
+             else:
+                 response = response.choices[0].message.content
+                 return response
+
+     def run_function(self, output):
+         try:
+             tool_call = output.split('Tool-call:')[-1].rstrip().replace('\n', '')
+             res = eval(tool_call)
+             return res
+         except Exception as e:
+             print(f"Error in parsing tool call in {output} got this error {e}")
+             import pdb; pdb.set_trace()
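`run_function` passes model-generated text straight to `eval`, which executes arbitrary code. A whitelist-and-dispatch pattern is a safer alternative; the sketch below is illustrative only (the `ToolBox` class, its `obtain_moa` payload, and the regex-based argument parsing are assumptions, not part of this repo):

```python
import re

class ToolBox:
    """Illustrative stand-in for the agent, with one whitelisted tool."""

    ALLOWED = {"obtain_moa"}  # only these method names may be dispatched

    def obtain_moa(self, drug_name):
        # Hypothetical payload; the real tool reads Tahoe-100M metadata.
        return {"broad": "kinase inhibitor",
                "specific": f"CDK4/6 inhibitor ({drug_name})"}

    def dispatch(self, tool_call):
        # Accept only calls shaped like self.<name>(...) against the
        # whitelist, instead of eval'ing arbitrary model output.
        match = re.fullmatch(r'self\.(\w+)\((.*)\)', tool_call.strip())
        if not match:
            raise ValueError(f"Malformed tool call: {tool_call!r}")
        name, raw_args = match.groups()
        if name not in self.ALLOWED:
            raise ValueError(f"Unknown tool: {name}")
        # Naive argument parsing: split on commas, strip surrounding quotes.
        args = [a.strip().strip('\'"') for a in raw_args.split(",")] if raw_args else []
        return getattr(self, name)(*args)

box = ToolBox()
result = box.dispatch('self.obtain_moa("Abemaciclib")')
print(result["specific"])  # → CDK4/6 inhibitor (Abemaciclib)
```

The naive comma-split would break on arguments containing commas; `ast.literal_eval` on the argument tuple would be a more robust choice in practice.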
data/README.md ADDED
@@ -0,0 +1,7 @@
+ # Data Download
+
+ Download the JUMP signatures:
+
+ ```
+ wget https://cellpainting-gallery.s3.amazonaws.com/cpg0016-jump-assembled/source_all/workspace/profiles/jump-profiling-recipe_2024_a917fa7/COMPOUND/profiles_var_mad_int_featselect_harmony/profiles_var_mad_int.parquet
+ ```
data/jump-dataset.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
data/jump-similarity.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
model.py ADDED
@@ -0,0 +1,88 @@
+ import anndata
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import DataLoader, TensorDataset
+ import numpy as np
+
+ # Generate random sample data - 2000 samples, 1280 dimensions each
+ num_samples = 2000
+ dimension = 1280
+
+ # Generate random input vectors X
+ X = np.random.randn(num_samples, dimension)
+
+ # Generate target vectors Y (could be random or a function of X)
+ # Option 1: Completely random Y
+ # Y = np.random.randn(num_samples, dimension)
+
+ # Option 2: Y as a noisy function of X (more realistic for a regression task)
+ W = np.random.randn(dimension, dimension) * 0.1  # Random weight matrix
+ b = np.random.randn(dimension) * 0.1  # Random bias
+ noise = np.random.randn(num_samples, dimension) * 0.05  # Random noise
+ Y = X @ W + b + noise  # Y = XW + b + noise
+
+ # Convert data to PyTorch tensors
+ X_tensor = torch.tensor(X, dtype=torch.float32)
+ Y_tensor = torch.tensor(Y, dtype=torch.float32)
+ dataset = TensorDataset(X_tensor, Y_tensor)
+ dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
+
+ # Option 1: Simple linear regression model
+ class LinearModel(nn.Module):
+     def __init__(self):
+         super(LinearModel, self).__init__()
+         self.linear = nn.Linear(1280, 1280)
+
+     def forward(self, x):
+         return self.linear(x)
+
+ # Option 2: Neural network with hidden layers
+ class NeuralNetwork(nn.Module):
+     def __init__(self, hidden_dim=512):
+         super(NeuralNetwork, self).__init__()
+         self.network = nn.Sequential(
+             nn.Linear(1280, hidden_dim),
+             nn.ReLU(),
+             nn.Linear(hidden_dim, hidden_dim),
+             nn.ReLU(),
+             nn.Linear(hidden_dim, 1280)
+         )
+
+     def forward(self, x):
+         return self.network(x)
+
+ # Choose which model to use
+ # model = LinearModel()
+ model = NeuralNetwork()
+
+ # Loss function and optimizer
+ criterion = nn.MSELoss()
+ optimizer = optim.Adam(model.parameters(), lr=0.001)
+
+ # Training loop
+ num_epochs = 50
+ for epoch in range(num_epochs):
+     total_loss = 0
+     for inputs, targets in dataloader:
+         # Forward pass
+         outputs = model(inputs)
+         loss = criterion(outputs, targets)
+
+         # Backward pass and optimize
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         total_loss += loss.item()
+
+     # Print progress
+     if (epoch + 1) % 5 == 0:
+         print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(dataloader):.4f}')
+
+ # After training, to use the model for prediction:
+ def predict(input_vector):
+     model.eval()
+     with torch.no_grad():
+         input_tensor = torch.tensor(input_vector, dtype=torch.float32)
+         return model(input_tensor).numpy()
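The synthetic target above (`Y = XW + b + noise`) is exactly linear, so it can be sanity-checked without any training loop: ordinary least squares recovers `W` and `b` directly. A minimal numpy-only sketch (dimensions shrunk from 1280 to 16 for speed; this is an illustration, not part of the repo):

```python
import numpy as np

# Smaller dimensions than the script above, so the closed-form fit is quick.
rng = np.random.default_rng(0)
n, d = 500, 16

X = rng.standard_normal((n, d))
W = rng.standard_normal((d, d)) * 0.1
b = rng.standard_normal(d) * 0.1
Y = X @ W + b + rng.standard_normal((n, d)) * 0.05

# Closed-form least squares on [X, 1] recovers W and b up to noise.
X_aug = np.hstack([X, np.ones((n, 1))])
coef, *_ = np.linalg.lstsq(X_aug, Y, rcond=None)
W_hat, b_hat = coef[:d], coef[d]

print("max |W_hat - W| =", np.abs(W_hat - W).max())
```

With this noise level the recovery error is tiny; if the neural network in `model.py` cannot beat the linear fit on such data, something is wrong with the training setup.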
run_app.py ADDED
@@ -0,0 +1,296 @@
+ import random
+ import datetime
+ import sys
+ from agent.agent import SigSpace
+ import spaces
+ import gradio as gr
+ import os
+ from PIL import Image
+
+ os.environ["VLLM_USE_V1"] = "0"  # Disable v1 API for now since it does not support logits processors.
+
+ # Determine the directory where the current file is located
+ current_dir = os.path.dirname(os.path.abspath(__file__))
+ os.environ["MKL_THREADING_LAYER"] = "GNU"
+
+ # Set an environment variable
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+
+
+ # Create the image path - use an absolute path for reliability
+ img_path = os.path.join(current_dir, 'img', 'SigSpace.png')
+
+ def display_image(image_path):
+     # Load and return the image
+     img = Image.open(image_path)
+     return img
+
+ DESCRIPTION = f'''
+ <div style="text-align: center;">
+     <h1 style="font-size: 32px; margin-bottom: 10px;">SigSpace: An AI Agent for Tahoe-100M</h1>
+ </div>
+ '''
+ INTRO = """
+ This is the intro that goes here
+ """
+
+ LICENSE = """
+ License goes here
+ """
+
+ PLACEHOLDER = """
+ <div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
+     <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">Agent</h1>
+     <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Tips before using Agent:</p>
+     <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Please click clear🗑️
+ (top-right) to remove previous context before submitting a new question.</p>
+     <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.55;">Click retry🔄 (below message) to get multiple versions of the answer.</p>
+ </div>
+ """
+
+ css = """
+ h1 {
+     text-align: center;
+     display: block;
+ }
+
+ #duplicate-button {
+     margin: auto;
+     color: white;
+     background: #1565c0;
+     border-radius: 100vh;
+ }
+ .small-button button {
+     font-size: 12px !important;
+     padding: 4px 8px !important;
+     height: 6px !important;
+     width: 4px !important;
+ }
+ .gradio-accordion {
+     margin-top: 0px !important;
+     margin-bottom: 0px !important;
+ }
+ """
+
+ chat_css = """
+ .gr-button { font-size: 20px !important; }  /* Enlarges button icons */
+ .gr-button svg { width: 32px !important; height: 32px !important; }  /* Enlarges SVG icons */
+ """
+
+ model_name = ''
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+
+ question_examples = [
+     # ['What is the IC50 value for the drug Abemaciclib in the cell line A549?'],
+     ["What's the MoA of the drug Ponatinib on the HCT15 colon cancer cell line? Please synthesize results from the Tahoe-100M dataset, the jump dataset, and the IC50 dataset."],
+     ["Natural perturbation: find the disease perturbation that has a similar effect to Glycyrrhizic acid on CVCL_0334. Use the result and what you know to explain the mechanism of action."],
+     ["Mechanism of action: give me the mechanism of action for drug name Abemaciclib provided by Tahoe."],
+     ["Vision scores: what are the top 5 vision scores for cell line A549 and drug name Abemaciclib?"]
+ ]
+
+ new_tool_files = {
+     'new_tool': os.path.join(current_dir, 'data', 'new_tool.json'),
+ }
+
+ config_path = "/home/ubuntu/.lambda_api_config.yaml"
+ agent = SigSpace(config_path)
+ # agent.init_model()
+
+
+ def update_model_parameters(enable_finish, enable_rag, enable_summary,
+                             init_rag_num, step_rag_num, skip_last_k,
+                             summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed):
+     # Update model instance parameters dynamically
+     updated_params = agent.update_parameters(
+         enable_finish=enable_finish,
+         enable_rag=enable_rag,
+         enable_summary=enable_summary,
+         init_rag_num=init_rag_num,
+         step_rag_num=step_rag_num,
+         skip_last_k=skip_last_k,
+         summary_mode=summary_mode,
+         summary_skip_last_k=summary_skip_last_k,
+         summary_context_length=summary_context_length,
+         force_finish=force_finish,
+         seed=seed,
+     )
+
+     return updated_params
+
+
+ def update_seed():
+     # Draw a fresh random seed and push it to the agent
+     seed = random.randint(0, 10000)
+     updated_params = agent.update_parameters(
+         seed=seed,
+     )
+     return updated_params
+
+
+ def handle_retry(history, retry_data: gr.RetryData, temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round):
+     print("Updated seed:", update_seed())
+     new_history = history[:retry_data.index]
+     previous_prompt = history[retry_data.index]['content']
+
+     print("previous_prompt", previous_prompt)
+
+     yield from agent.run_gradio_chat(new_history + [{"role": "user", "content": previous_prompt}], temperature, max_new_tokens, max_tokens, multi_agent, conversation, max_round)
+
+
+ PASSWORD = "mypassword"
+
+ # Function to check if the password is correct
+ def check_password(input_password):
+     if input_password == PASSWORD:
+         return gr.update(visible=True), ""
+     else:
+         return gr.update(visible=False), "Incorrect password, try again!"
+
+
+ conversation_state = gr.State([])
+
+ # Gradio block
+ chatbot = gr.Chatbot(height=400, placeholder=PLACEHOLDER,
+                      label='SigSpace', type="messages", show_copy_button=True)
+
+ with gr.Blocks(css=css) as demo:
+     gr.Markdown(DESCRIPTION)
+     # gr.Markdown(INTRO)
+     gr.Image(value=display_image(img_path), label="", show_label=False, height=600, width=600)
+     default_temperature = 0.3
+     default_max_new_tokens = 1024
+     default_max_tokens = 81920
+     default_max_round = 30
+     temperature_state = gr.State(value=default_temperature)
+     max_new_tokens_state = gr.State(value=default_max_new_tokens)
+     max_tokens_state = gr.State(value=default_max_tokens)
+     max_round_state = gr.State(value=default_max_round)
+     chatbot.retry(handle_retry, chatbot, chatbot, temperature_state, max_new_tokens_state,
+                   max_tokens_state, gr.Checkbox(value=False, render=False), conversation_state, max_round_state)
+
+     gr.ChatInterface(
+         fn=agent.run_gradio_chat,
+         chatbot=chatbot,
+         fill_height=False, fill_width=False, stop_btn=True,
+         additional_inputs_accordion=gr.Accordion(
+             label="⚙️ Inference Parameters", open=False, render=False),
+         additional_inputs=[
+             temperature_state, max_new_tokens_state, max_tokens_state,
+             gr.Checkbox(
+                 label="Activate X", value=False, render=False),
+             conversation_state,
+             max_round_state,
+             gr.Number(label="Seed", value=100, render=False)
+         ],
+         examples=question_examples,
+         cache_examples=False,
+         css=chat_css,
+     )
+
+     with gr.Accordion("Settings", open=False):
+
+         # Define the sliders
+         temperature_slider = gr.Slider(
+             minimum=0,
+             maximum=1,
+             step=0.1,
+             value=default_temperature,
+             label="Temperature"
+         )
+         max_new_tokens_slider = gr.Slider(
+             minimum=128,
+             maximum=4096,
+             step=1,
+             value=default_max_new_tokens,
+             label="Max new tokens"
+         )
+         max_tokens_slider = gr.Slider(
+             minimum=128,
+             maximum=32000,
+             step=1,
+             value=default_max_tokens,
+             label="Max tokens"
+         )
+         max_round_slider = gr.Slider(
+             minimum=0,
+             maximum=50,
+             step=1,
+             value=default_max_round,
+             label="Max round")
+
+         # Automatically update states when slider values change
+         temperature_slider.change(
+             lambda x: x, inputs=temperature_slider, outputs=temperature_state)
+         max_new_tokens_slider.change(
+             lambda x: x, inputs=max_new_tokens_slider, outputs=max_new_tokens_state)
+         max_tokens_slider.change(
+             lambda x: x, inputs=max_tokens_slider, outputs=max_tokens_state)
+         max_round_slider.change(
+             lambda x: x, inputs=max_round_slider, outputs=max_round_state)
+
+     # password_input = gr.Textbox(
+     #     label="Enter Password for More Settings", type="password")
+     # incorrect_message = gr.Textbox(visible=False, interactive=False)
+     # with gr.Accordion("⚙️ Settings", open=False, visible=False) as protected_accordion:
+     #     with gr.Row():
+     #         with gr.Column(scale=1):
+     #             with gr.Accordion("⚙️ Model Loading", open=False):
+     #                 model_name_input = gr.Textbox(
+     #                     label="Enter model path", value=model_name)
+     #                 load_model_btn = gr.Button(value="Load Model")
+     #                 load_model_btn.click(
+     #                     agent.load_models, inputs=model_name_input, outputs=gr.Textbox(label="Status"))
+     #         with gr.Column(scale=1):
+     #             with gr.Accordion("⚙️ Functional Parameters", open=False):
+     #                 # Create Gradio components for parameter inputs
+     #                 enable_finish = gr.Checkbox(
+     #                     label="Enable Finish", value=True)
+     #                 enable_rag = gr.Checkbox(
+     #                     label="Enable RAG", value=True)
+     #                 enable_summary = gr.Checkbox(
+     #                     label="Enable Summary", value=False)
+     #                 init_rag_num = gr.Number(
+     #                     label="Initial RAG Num", value=0)
+     #                 step_rag_num = gr.Number(
+     #                     label="Step RAG Num", value=10)
+     #                 skip_last_k = gr.Number(label="Skip Last K", value=0)
+     #                 summary_mode = gr.Textbox(
+     #                     label="Summary Mode", value='step')
+     #                 summary_skip_last_k = gr.Number(
+     #                     label="Summary Skip Last K", value=0)
+     #                 summary_context_length = gr.Number(
+     #                     label="Summary Context Length", value=None)
+     #                 force_finish = gr.Checkbox(
+     #                     label="Force FinalAnswer", value=True)
+     #                 seed = gr.Number(label="Seed", value=100)
+     #                 # Button to submit and update parameters
+     #                 submit_btn = gr.Button("Update Parameters")
+
+     #                 # Display the updated parameters
+     #                 updated_parameters_output = gr.JSON()
+
+     #                 # When button is clicked, update parameters
+     #                 submit_btn.click(fn=update_model_parameters,
+     #                                  inputs=[enable_finish, enable_rag, enable_summary, init_rag_num, step_rag_num, skip_last_k,
+     #                                          summary_mode, summary_skip_last_k, summary_context_length, force_finish, seed],
+     #                                  outputs=updated_parameters_output)
+     # Button to submit the password
+     # submit_button = gr.Button("Submit")
+
+     # # When the button is clicked, check if the password is correct
+     # submit_button.click(
+     #     check_password,
+     #     inputs=password_input,
+     #     outputs=[protected_accordion, incorrect_message]
+     # )
+     gr.Markdown(LICENSE)
+
+
+ if __name__ == "__main__":
+     demo.launch(share=True)
tahoe_model/apply_linear_model.py ADDED
@@ -0,0 +1,93 @@
+ import anndata
+ import joblib
+ import matplotlib.pyplot as plt
+ import numpy as np
+ from sklearn.model_selection import train_test_split
+ from sklearn.linear_model import LinearRegression
+ from sklearn.metrics import mean_squared_error
+ from scipy.stats import pearsonr
+ import os
+ import pandas as pd
+
+ merged_anndata = anndata.read_h5ad("data/tahoe_vision_universal_embeddings.h5ad")
+
+ X = merged_anndata.obsm["X_delta"]  # 60125 x 1280
+ Y = merged_anndata.X  # 60125 x 7467
+ labels = merged_anndata.var.index.tolist()  # 7467
+
+ X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=42)
+
+ if os.path.exists("models/linear_regression_model.pkl"):
+     model = joblib.load("models/linear_regression_model.pkl")
+
+     y_pred_test = model.predict(X_test)
+     test_pearson = [pearsonr(y_test[:, i], y_pred_test[:, i])[0] for i in range(y_test.shape[1])]
+
+     top_gene_set_indices = np.argsort(test_pearson)[-20:][::-1]
+
+     top_gene_sets = [(test_pearson[i], labels[i]) for i in top_gene_set_indices]
+
+     print("Top 20 gene sets with the highest correlation:")
+     for correlation, gene_set in top_gene_sets:
+         print(f"gene set {gene_set}: pearson correlation = {correlation:.4f}")
+
+     plt.hist(test_pearson, bins=50, color='blue', alpha=0.7)
+     plt.title("Distribution of Pearson Correlation Coefficients (Test Set)")
+     plt.xlabel("Pearson Correlation Coefficient")
+     plt.ylabel("Frequency")
+     plt.grid(axis='y', alpha=0.75)
+
+     if not os.path.exists("figures"):
+         os.makedirs("figures")
+
+     plt.savefig("figures/pearson_correlation_distribution.png")
+
+     top_20_indices_per_row = np.argsort(np.abs(y_test), axis=1)[:, -20:]
+
+     correlations = []
+     for i in range(y_test.shape[0]):
+         actual_top_20 = y_test[i, top_20_indices_per_row[i]]
+         predicted_top_20 = y_pred_test[i, top_20_indices_per_row[i]]
+         correlation = pearsonr(actual_top_20, predicted_top_20)[0]
+         correlations.append(correlation)
+
+     average_correlation = np.mean(correlations)
+     print(f"Average correlation for top 20 magnitude gene sets per row: {average_correlation:.4f}")
+
+ else:
+     model = LinearRegression()
+
+     model.fit(X_train, y_train)
+
+     y_pred_train = model.predict(X_train)
+     y_pred_test = model.predict(X_test)
+
+     train_mse = mean_squared_error(y_train, y_pred_train)
+     test_mse = mean_squared_error(y_test, y_pred_test)
+
+     print(f"training MSE: {train_mse}")
+     print(f"testing MSE: {test_mse}")
+
+     joblib.dump(model, "models/linear_regression_model.pkl")
+
+ model = joblib.load("models/linear_regression_model.pkl")
+
+ disease_deltas = anndata.read_h5ad("data/disease_deltas.h5ad")
+ predicted_vision_signatures = model.predict(disease_deltas.X)
+
+ dataframe = pd.DataFrame(predicted_vision_signatures, columns=labels)
+
+ labels_combined = disease_deltas.obs.apply(
+     lambda row: f"{row['cell_type']}_{row['tissue']}_{row['disease']}", axis=1
+ ).tolist()
+
+ top_20_gene_sets = []
+
+ for index, row in dataframe.iterrows():
+     top_20_indices = np.argsort(np.abs(row))[-20:][::-1]
+     top_20 = [(labels[i], "down" if row.iloc[i] < 0 else "up") for i in top_20_indices]
+     top_20_gene_sets.append(top_20)
+
+ with open("top_20_gene_sets.txt", "w") as f:
+     for i, gene_set in enumerate(top_20_gene_sets):
+         f.write(f"{labels_combined[i]}\t" + "\t".join([f"{gene}:{direction}" for gene, direction in gene_set]) + "\n")
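The per-gene-set Pearson evaluation used above can be illustrated on synthetic data; `np.corrcoef` stands in for `scipy.stats.pearsonr`, and the shapes and noise level are made up for the sketch:

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples, n_gene_sets = 200, 30

y_true = rng.standard_normal((n_samples, n_gene_sets))
# Predictions = truth plus noise, so correlations should be high but < 1.
y_pred = y_true + rng.standard_normal((n_samples, n_gene_sets)) * 0.3

# Column-wise Pearson correlation: one value per gene set.
per_set_r = np.array([
    np.corrcoef(y_true[:, i], y_pred[:, i])[0, 1] for i in range(n_gene_sets)
])

# Rank gene sets by correlation, best first (mirrors the argsort above).
top_idx = np.argsort(per_set_r)[-5:][::-1]
print("best r:", per_set_r[top_idx[0]])
```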
tahoe_model/compute_tahoe_deltas.py ADDED
@@ -0,0 +1,26 @@
+ import anndata
+ import numpy as np
+
+ uce = anndata.read_h5ad("data/tahoe_universal_embeddings.h5ad")
+
+ control_condition = "[('DMSO_TF', 0.0, 'uM')]"
+
+ X_delta = np.zeros_like(uce.obsm["X_uce"])
+
+ for cell_line in uce.obs["cell_line"].unique():
+     for plate in uce.obs["plate"].unique():
+         cell_plate_mask = (uce.obs["cell_line"] == cell_line) & (uce.obs["plate"] == plate)
+         control_mask = cell_plate_mask & (uce.obs["drugname_drugconc"] == control_condition)
+
+         cell_plate_indices = np.where(cell_plate_mask)[0]
+         control_indices = np.where(control_mask)[0]
+
+         # Subtract the DMSO control embedding from every embedding in this
+         # (cell line, plate) group; the broadcast assumes one control row per group.
+         X_delta[cell_plate_indices] = uce.obsm["X_uce"][cell_plate_indices] - uce.obsm["X_uce"][control_indices]
+
+
+ uce.obsm["X_delta"] = X_delta
+
+ print("X_uce shape", uce.obsm["X_uce"].shape)
+ print("X_delta shape", uce.obsm["X_delta"].shape)
+
+ uce.write("data/tahoe_universal_embeddings_deltas.h5ad")
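The delta computation above subtracts each (cell line, plate) group's control embedding from every embedding in that group. A toy numpy sketch of the same pattern (group labels, dimensions, and embeddings are synthetic, and it likewise assumes exactly one control row per group):

```python
import numpy as np

rng = np.random.default_rng(2)
n_obs, dim = 6, 4
emb = rng.standard_normal((n_obs, dim))

groups = np.array(["A", "A", "A", "B", "B", "B"])       # e.g. cell_line + plate
is_control = np.array([True, False, False, True, False, False])  # e.g. DMSO wells

deltas = np.zeros_like(emb)
for g in np.unique(groups):
    group_mask = groups == g
    # Assumes exactly one control row per group, as in the script above.
    control_row = emb[group_mask & is_control][0]
    deltas[group_mask] = emb[group_mask] - control_row

# Control rows end up as zero vectors; treated rows are offsets from control.
print(deltas[0], deltas[3])
```

If a group can contain several control wells, replacing `[0]` with `.mean(axis=0)` over the control rows would be the natural generalization.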
tahoe_model/merge_tahoe_vision.py ADDED
@@ -0,0 +1,66 @@
+ import anndata
+ import pandas as pd
+
+ uce = anndata.read_h5ad("data/tahoe_universal_embeddings_deltas.h5ad")
+ vision = anndata.read_h5ad("data/tahoe_vision_scores.h5ad")
+
+ uce.obs = uce.obs.reset_index().rename(columns={"index": "condition"})
+ vision.obs["condition"] = vision.obs.apply(
+     lambda row: f"{row['Cell_ID_Cellosaur']}_{[(row['drug'], row['concentration'], row['concentration_unit'])]}_plate{row['plate']}", axis=1
+ )
+
+ unique_cell_lines_uce = uce.obs["cell_line"].unique().tolist()
+ unique_drugs_uce = uce.obs["drugname_drugconc"].apply(
+     lambda x: eval(x)[0][0]
+ ).unique().tolist()
+
+ # print("number of unique cell lines:", len(unique_cell_lines_uce))
+ # print(unique_cell_lines_uce)
+
+ # print("\nnumber of unique drugs:", len(unique_drugs_uce))
+ # print(unique_drugs_uce)
+
+ conditions_uce = set(uce.obs["condition"].unique())
+ conditions_vision = set(vision.obs["condition"].unique())
+
+ only_in_uce = conditions_uce - conditions_vision
+ only_in_vision = conditions_vision - conditions_uce
+
+ with open("conditions_only_in_uce.txt", "w") as f:
+     for condition in only_in_uce:
+         f.write(f"{condition}\n")
+
+ with open("conditions_only_in_vision.txt", "w") as f:
+     for condition in only_in_vision:
+         f.write(f"{condition}\n")
+
+ vision = vision[vision.obs["condition"].drop_duplicates(keep="first").index, :]
+ vision.obs = vision.obs.reset_index(drop=True)
+
+ merged_obs = pd.merge(
+     uce.obs,
+     vision.obs,
+     on="condition",
+     how="inner"
+ )
+
+ indices_in_vision = vision.obs.index[
+     vision.obs["condition"].isin(merged_obs["condition"])
+ ].tolist()
+
+ indices_in_uce = uce.obs.index[
+     uce.obs["condition"].isin(merged_obs["condition"])
+ ].tolist()
+
+ indices_in_vision = [int(x) for x in indices_in_vision]
+ indices_in_uce = [int(x) for x in indices_in_uce]
+
+ anndata_merged = anndata.AnnData(
+     X=vision.X[indices_in_vision, :],
+     obs=merged_obs,
+     var=vision.var,
+     obsm={"X_uce": uce.obsm["X_uce"][indices_in_uce, :],
+           "X_delta": uce.obsm["X_delta"][indices_in_uce, :]}
+ )
+
+ anndata_merged.write("data/tahoe_vision_universal_embeddings.h5ad")
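The alignment step above pairs rows from the two AnnData objects through an inner merge on the shared `condition` key. A small pandas sketch of that pattern (toy frames, not the real metadata):

```python
import pandas as pd

uce_obs = pd.DataFrame({
    "condition": ["c1", "c2", "c3"],
    "cell_line": ["A549", "A549", "HCT15"],
})
vision_obs = pd.DataFrame({
    "condition": ["c2", "c3", "c4"],
    "score": [0.1, 0.5, 0.9],
})

# Inner merge keeps only conditions present in both tables.
merged = pd.merge(uce_obs, vision_obs, on="condition", how="inner")
print(merged["condition"].tolist())
# → ['c2', 'c3']
```

Rows unique to either side (`c1`, `c4`) are dropped, which is exactly why the script writes the `conditions_only_in_*.txt` files before merging: they record what the inner join discards.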