| from typing import Any, Dict, List, Optional
|
|
|
| from langchain.callbacks.streamlit.streamlit_callback_handler import (
|
|
|
| LLMThought,
|
| LLMThoughtLabeler,
|
| LLMThoughtState,
|
| StreamlitCallbackHandler,
|
| ToolRecord,
|
| )
|
| from langchain_core.agents import AgentAction, AgentFinish
|
| from streamlit.delta_generator import DeltaGenerator
|
|
|
| from utils import is_smiles
|
|
|
| import requests
|
| from langchain import LLMChain, PromptTemplate
|
| from langchain.chat_models import ChatOpenAI
|
| from rdkit import Chem
|
|
|
|
|
def cdk(smiles, timeout=10):
    """Fetch a depiction of a SMILES string from the CDK Depict web service.

    Args:
        smiles: SMILES (or reaction SMILES) string, sent as the ``smi``
            query parameter.
        timeout: Seconds to wait for the service before ``requests`` gives
            up. Without a timeout ``requests`` waits indefinitely, which
            would block the calling UI callback on an unresponsive server.

    Returns:
        The response body as text. The endpoint path ends in ``/svg``, so
        this is presumably SVG markup; the HTTP status code is not checked.

    Raises:
        requests.exceptions.RequestException: On connection failure or when
            the timeout elapses.
    """
    url = "https://www.simolecule.com/cdkdepict/depict/wob/svg"
    headers = {"Content-Type": "application/json"}
    # Rendering options forwarded to the depiction service as query params.
    params = {
        "smi": smiles,
        "annotate": "colmap",
        "zoom": 2,
        "w": 150,
        "h": 80,
        "abbr": "off",
    }
    response = requests.get(
        url,
        headers=headers,
        params=params,
        timeout=timeout,
    )
    return response.text
|
|
|
|
|
class LLMThoughtChem(LLMThought):
    """``LLMThought`` subclass that adds molecule/reaction depictions.

    Square brackets in labels and SMILES strings are backslash-escaped
    before being handed to the Streamlit container.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        labeler: LLMThoughtLabeler,
        expanded: bool,
        collapse_on_complete: bool,
    ):
        """Forward all arguments unchanged to ``LLMThought.__init__``."""
        super().__init__(
            parent_container,
            labeler,
            expanded,
            collapse_on_complete,
        )

    @staticmethod
    def _escape_brackets(text: str) -> str:
        """Return *text* with ``[`` and ``]`` prefixed by a backslash."""
        # "\\[" is the explicit spelling of what the invalid escape "\["
        # evaluated to; it avoids the SyntaxWarning on newer Pythons.
        return text.replace("[", "\\[").replace("]", "\\]")

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        output_ph: Optional[dict] = None,
        input_tool: str = "",
        serialized: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Render a depiction of the tool output for selected tools.

        Args:
            output: Text returned by the tool.
            color: Accepted for interface compatibility; unused here.
            observation_prefix: Accepted for interface compatibility; unused.
            llm_prefix: Accepted for interface compatibility; unused.
            output_ph: Accepted for interface compatibility; unused here.
            input_tool: Input string the tool was invoked with; used as the
                left-hand side of the reaction for ``ReactionPredict``.
            serialized: Tool description; only its ``"name"`` key is read.
                When missing or lacking a name, nothing is rendered.
            **kwargs: Ignored.
        """
        # Tolerate a missing dict / missing key instead of raising KeyError.
        tool_name = (serialized or {}).get("name")

        if tool_name == "Name2SMILES":
            # Only draw the molecule when the tool really returned SMILES.
            if is_smiles(output):
                safe_smiles = self._escape_brackets(output)
                self._container.markdown(
                    f"**{safe_smiles}**{cdk(output)}", unsafe_allow_html=True
                )
        elif tool_name == "ReactionPredict":
            # Build reaction SMILES "reactants>>product" for the depiction.
            rxn = f"{input_tool}>>{output}"
            safe_smiles = self._escape_brackets(rxn)
            self._container.markdown(
                f"**{safe_smiles}**{cdk(rxn)}", unsafe_allow_html=True
            )
        # NOTE(review): the previous code escaped `output` for
        # "ReactionRetrosynthesis" but never displayed it, so that branch
        # had no effect and was dropped. Confirm whether the escaped
        # output was meant to be rendered.

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Mark the thought as running a tool and update the label.

        Args:
            serialized: Tool description; must contain a ``"name"`` key.
            input_str: Input string passed to the tool.
            **kwargs: Ignored.
        """
        self._state = LLMThoughtState.RUNNING_TOOL
        tool_name = serialized["name"]
        self._last_tool = ToolRecord(name=tool_name, input_str=input_str)

        label = self._labeler.get_tool_label(self._last_tool, is_complete=False)
        self._container.update(new_label=self._escape_brackets(label))

        # Warn the user up front about the tools flagged as slow.
        if tool_name in ("ReactionRetrosynthesis", "LiteratureSearch"):
            self._container.markdown(
                "‼️ Note: This tool can take some time to complete execution ‼️",
                unsafe_allow_html=True,
            )

    def complete(self, final_label: Optional[str] = None) -> None:
        """Finish the thought.

        Args:
            final_label: Label to show on the container. When ``None`` and
                a tool is running, the labeler's "complete" tool label is
                used; when still ``None`` after that, the label is passed
                to the container as ``None``.
        """
        if final_label is None and self._state == LLMThoughtState.RUNNING_TOOL:
            assert (
                self._last_tool is not None
            ), "_last_tool should never be null when _state == RUNNING_TOOL"
            final_label = self._labeler.get_tool_label(
                self._last_tool, is_complete=True
            )
        self._state = LLMThoughtState.COMPLETE

        # final_label may still be None here (no explicit label and no tool
        # running); escaping unconditionally would raise AttributeError.
        if final_label is not None:
            final_label = self._escape_brackets(final_label)

        if self._collapse_on_complete:
            self._container.update(new_label=final_label, new_expanded=False)
        else:
            self._container.update(new_label=final_label)
|
|
|
|
|
class StreamlitCallbackHandlerChem(StreamlitCallbackHandler):
    """``StreamlitCallbackHandler`` that creates ``LLMThoughtChem`` thoughts.

    It remembers the most recent tool input and tool description from
    ``on_tool_start`` and forwards them to the thought in ``on_tool_end``.
    """

    def __init__(
        self,
        parent_container: DeltaGenerator,
        *,
        max_thought_containers: int = 4,
        expand_new_thoughts: bool = True,
        collapse_completed_thoughts: bool = True,
        thought_labeler: Optional[LLMThoughtLabeler] = None,
        output_placeholder: Optional[dict] = None,
    ):
        """Initialise the handler.

        Args:
            parent_container: Forwarded to the base handler.
            max_thought_containers: Forwarded to the base handler.
            expand_new_thoughts: Forwarded to the base handler.
            collapse_completed_thoughts: Forwarded to the base handler.
            thought_labeler: Forwarded to the base handler.
            output_placeholder: Dict passed to the thought's ``on_tool_end``
                as ``output_ph``. Defaults to a fresh empty dict.
        """
        super().__init__(
            parent_container,
            max_thought_containers=max_thought_containers,
            expand_new_thoughts=expand_new_thoughts,
            collapse_completed_thoughts=collapse_completed_thoughts,
            thought_labeler=thought_labeler,
        )

        # A fresh dict per instance avoids sharing one mutable default.
        self._output_placeholder = (
            {} if output_placeholder is None else output_placeholder
        )
        # State read by on_tool_end; initialised here so on_tool_end does
        # not hit an AttributeError if it runs before any on_tool_start.
        self._last_input = ""
        self._serialized: Dict[str, Any] = {}
        # Kept for backward compatibility: earlier versions set this public
        # attribute (the methods themselves only use `_last_input`).
        self.last_input = ""

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Create an ``LLMThoughtChem`` if none is active, then notify it."""
        if self._current_thought is None:
            self._current_thought = LLMThoughtChem(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)

    def on_tool_start(
        self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
    ) -> None:
        """Notify the current thought and remember the tool call details."""
        self._require_current_thought().on_tool_start(serialized, input_str, **kwargs)
        self._prune_old_thought_containers()
        # Stored so on_tool_end can tell which tool ran and with what input.
        self._last_input = input_str
        self._serialized = serialized

    def on_tool_end(
        self,
        output: str,
        color: Optional[str] = None,
        observation_prefix: Optional[str] = None,
        llm_prefix: Optional[str] = None,
        **kwargs: Any,
    ) -> None:
        """Forward the tool output plus remembered call details, then
        complete the current thought."""
        self._require_current_thought().on_tool_end(
            output,
            color,
            observation_prefix,
            llm_prefix,
            output_ph=self._output_placeholder,
            input_tool=self._last_input,
            serialized=self._serialized,
            **kwargs,
        )
        self._complete_current_thought()

    def on_agent_finish(
        self, finish: AgentFinish, color: Optional[str] = None, **kwargs: Any
    ) -> None:
        """Complete the current thought, if any, with the final-agent label."""
        if self._current_thought is not None:
            # LLMThoughtChem.complete() escapes square brackets itself;
            # escaping here as well would double-escape them.
            self._current_thought.complete(
                self._thought_labeler.get_final_agent_thought_label()
            )
            self._current_thought = None