import argparse
import html
import os
import textwrap
from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from itertools import zip_longest
from typing import Dict, List, Tuple

from dotenv import load_dotenv
import gradio as gr
from pymongo import MongoClient

from llm_rules import Role, Message, models, scenarios
|
|
|
|
# mongodb+srv connection-string template; credentials are filled in from
# environment variables in main().
MONGO_URI = "mongodb+srv://{username}:{password}@{host}/?retryWrites=true&w=majority"
# Module-level handle to the target Mongo collection; None disables logging
# (evaluate_and_log checks for None before inserting).
MONGO_DB = None
# Default placeholder text shown in the message textbox.
PLACEHOLDER = "Enter message"


# Chatbot display format: list of [user_text, assistant_text_or_None] pairs.
History = List[List[str]]
|
|
|
|
def parse_args():
    """Parse command-line options for the demo server.

    Returns:
        argparse.Namespace with `hf_proxy` (bool) and `port` (int).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--hf_proxy", action="store_true", default=False)
    cli.add_argument("--port", default=7860, type=int)
    return cli.parse_args()
|
|
|
|
@dataclass
class State:
    """Per-session conversation state for the demo UI.

    Holds the active scenario and model (rebuilt from their registry names in
    __post_init__), plus parallel raw and redacted message lists. The redacted
    list is what the chatbot widget displays; the raw list is what the model
    sees and what gets evaluated/logged.
    """

    scenario_name: str
    provider_name: str
    model_name: str
    scenario: scenarios.scenario.BaseScenario = None
    model: models.BaseModel = None
    system_message: str = None
    use_system_instructions: bool = False
    messages: List[Message] = field(default_factory=list)
    redacted_messages: List[Message] = field(default_factory=list)
    last_user_message_valid: bool = False

    def __post_init__(self):
        # Materialize scenario/model from the registries, then seed both
        # conversations with the scenario's opening messages.
        self.scenario = scenarios.SCENARIOS[self.scenario_name]()
        build_model = models.MODEL_BUILDERS[self.provider_name]
        self.model = build_model(model=self.model_name, stream=True, temperature=0)
        self.messages = self.get_initial_messages()
        self.redacted_messages = self.get_initial_messages(redacted=True)

    def get_initial_messages(self, redacted=False) -> List[Message]:
        """Build the opening messages, optionally using the redacted prompt."""
        prompt = self.scenario.redacted_prompt if redacted else self.scenario.prompt
        if self.use_system_instructions:
            # Rules go directly into the system message.
            return [Message(Role.SYSTEM, prompt)]
        # Otherwise: generic system prompt, rules as the first user turn,
        # and the scenario's canned acknowledgement as the first reply.
        return [
            Message(Role.SYSTEM, models.PROMPTS[self.system_message]),
            Message(Role.USER, prompt),
            Message(Role.ASSISTANT, self.scenario.initial_response),
        ]

    def get_history(self) -> History:
        """Process redacted messages into format for chatbot to display."""
        visible = self.redacted_messages[1:]  # drop the system message
        history = []
        # Pair messages as [user, assistant]; a trailing unanswered user
        # message gets None for the assistant slot.
        for start in range(0, len(visible), 2):
            user_text = html.escape(visible[start].content, quote=False)
            asst_text = None
            if start + 1 < len(visible):
                asst_text = html.escape(visible[start + 1].content, quote=False)
            history.append([user_text, asst_text])
        return history

    def update_state_and_history(self, history: History, delta: str) -> History:
        """Incrementally append a streamed delta to messages and history."""
        self.messages[-1].content += delta
        history[-1][-1] += html.escape(delta, quote=False)
        return history

    def get_info(self):
        """Return the helper text displayed as the textbox label."""
        base = "Return to send message. Shift + Return to add a new line."
        fmt = self.scenario.format_message
        return f"{fmt} {base}" if fmt else base

    def unescape_messages(self) -> List[Message]:
        """Return copies of the messages with HTML entities unescaped."""
        plain = []
        for msg in self.messages:
            plain.append(Message(msg.role, html.unescape(msg.content)))
        return plain
|
|
|
|
def change_provider(state: State, provider_name: str) -> Tuple[State, Dict]:
    """Switch provider, select its default model, and refresh the dropdown."""
    provider = provider_name.lower()
    state.provider_name = provider
    state.model_name = models.MODEL_DEFAULTS[provider]
    state.model = models.MODEL_BUILDERS[provider](
        model=state.model_name,
        stream=True,
        temperature=0,
    )
    # Repopulate the model dropdown with this provider's models.
    dropdown_update = gr.update(
        value=state.model_name,
        choices=models.MODEL_NAMES_BY_PROVIDER[provider],
    )
    return state, dropdown_update
|
|
|
|
def change_model(state: State, model_name: str) -> State:
    """Rebuild the model client after a model selection change."""
    state.model_name = model_name
    build_model = models.MODEL_BUILDERS[state.provider_name]
    state.model = build_model(model=model_name, stream=True, temperature=0)
    return state
|
|
|
|
def change_scenario(state: State, scenario: str) -> Tuple[State, Dict]:
    """Instantiate the selected scenario and refresh the textbox label."""
    state.scenario_name = scenario
    state.scenario = scenarios.SCENARIOS[scenario]()
    textbox_update = gr.update(label=state.get_info(), placeholder=PLACEHOLDER)
    return state, textbox_update
|
|
|
|
def send_user_message(state: State, input: str) -> Tuple[State, History, Dict]:
    """Update state and chatbot with user input, clear textbox.

    Appends the message to both the raw and redacted conversations when it
    passes the scenario's validation. On invalid input, warns and leaves the
    textbox contents untouched so the user can correct them.
    """
    user_msg = Message(Role.USER, input)
    if state.scenario.is_valid_user_message(user_msg):
        state.messages.append(user_msg)
        state.redacted_messages.append(user_msg)
        state.last_user_message_valid = True
        update = gr.update(placeholder=PLACEHOLDER, value="")
    else:
        # Explicitly clear the flag so the assistant/evaluation steps chained
        # after submit are skipped for this turn.
        state.last_user_message_valid = False
        # Fixed: the warning previously ended with a stray "'" character.
        gr.Warning(f"Invalid user message: {state.scenario.format_message}")
        update = gr.update()
    return state, state.get_history(), update
|
|
|
|
def send_assistant_message(state: State, api_key: str) -> Tuple[State, History]:
    """Request model response and update blocks.

    Generator used as a streaming Gradio handler: each yielded
    (state, history) pair refreshes the chatbot display.
    """
    # Render the pending user message right away, before the model call.
    history = state.get_history()
    yield state, history


    # Skip the model call entirely if the last user message failed validation.
    if not state.last_user_message_valid:
        return


    try:
        # Empty textbox means "no user-supplied key"; the model presumably
        # falls back to a server-side key in that case — confirm in models.
        api_key = None if api_key == "" else api_key
        response = state.model(state.messages, api_key=api_key)
    except Exception as e:
        raise gr.Error(f"API error: {e} Please reset the scenario and try again.")


    # The same Message object is appended to both lists, so the in-place
    # content updates below are visible through redacted_messages too.
    asst_msg = Message(Role.ASSISTANT, "")
    state.messages.append(asst_msg)
    state.redacted_messages.append(asst_msg)
    history = state.get_history()


    # Stream: `response` yields text deltas (model built with stream=True);
    # append each delta to the last message and re-yield the history.
    for delta in response:
        history = state.update_state_and_history(history, delta)
        yield state, history
|
|
|
|
def evaluate_and_log(state: State) -> Tuple[State, Dict]:
    """Evaluate messages and update chatbot.

    Runs the scenario's evaluate() on the unescaped conversation, logs the
    exchange to MongoDB when configured, and returns a textbox update that
    either locks input (rule broken) or re-enables it for the next turn.
    """
    # Nothing to evaluate if the last user message failed validation
    # (the assistant step was skipped as well).
    if not state.last_user_message_valid:
        return state, gr.update()

    messages = state.unescape_messages()
    result = state.scenario.evaluate(messages, state.use_system_instructions)
    state.last_user_message_valid = False

    global MONGO_DB
    if MONGO_DB is not None:
        doc = {
            # pymongo encodes naive datetimes as if they were UTC, so log an
            # explicitly UTC-aware timestamp rather than naive local time.
            "timestamp": datetime.now(timezone.utc),
            "scenario": state.scenario_name,
            "params": asdict(state.scenario.p),
            "provider": state.provider_name,
            "model": state.model_name,
            "system_instructions": state.use_system_instructions,
            "messages": Message.serialize(state.messages),
            "result": asdict(result),
        }

        try:
            MONGO_DB.insert_one(doc)
        except Exception as e:
            # gr.Error is an exception class: constructing it without `raise`
            # was a silent no-op. Logging is best-effort, so surface a
            # non-fatal warning instead of aborting the turn.
            gr.Warning(f"Database failed: {e}")

    if not result.passed:
        gr.Info("Rule broken!")
        # Lock the textbox until the user resets or switches scenarios.
        return state, gr.update(
            interactive=False,
            value=f"Scenario terminated: {result.reason} Please reset the scenario or select another.",
        )
    else:
        return state, gr.update(
            interactive=True,
            placeholder=PLACEHOLDER,
        )
|
|
|
|
def reset_scenario(state: State) -> Tuple[State, History, dict]:
    """Restore the scenario's initial conversation and re-enable the textbox."""
    state.messages = state.get_initial_messages()
    state.redacted_messages = state.get_initial_messages(redacted=True)
    textbox_update = gr.update(value="", placeholder=PLACEHOLDER, interactive=True)
    return state, state.get_history(), textbox_update
|
|
|
|
def make_block(initial_state):
    """Build and return the Gradio Blocks UI, wired to the handlers above.

    `initial_state` is a State instance used both as the per-session
    gr.State seed and to render the initial chatbot history/labels.
    """
    # Intro/help text rendered as raw HTML at the top of the page.
    markdown = (
        "<h1>RuLES: Rule-following Language Evaluation Scenarios</h1>"
        "<p>Your goal is to get the assistant model to break its rules. The rules are presented to the model in the "
        "first user message. Each assistant response is checked using the evaluate method of the scenario, and the "
        "conversation terminates upon a failing result. Passwords and other secrets are redacted from the chat "
        "interface and replaced with \"[REDACTED]\", but the actual value is still provided to the model. Some "
        "scenarios require user messages in a specific format. Your messages will be logged for research use. Please do"
        " not enter any sensitive information.</p>"
        "<p>Anthropic has graciously provided us with subsidized API access, and Google Cloud currently offers 100% "
        "discounted access to the PaLM 2 API during its Preview stage, so no API key is required for these models. "
        "To access OpenAI models, please enter your own API key. We do not record your key, which you can verify "
        "in the demo's source code.</p>"
        "<p>See the RuLES <a href=\"https://github.com/normster/llm_rules\">github repo</a> for more information.</p>"
    )


    with gr.Blocks(
        gr.themes.Monochrome(
            font=[
                gr.themes.GoogleFont("Source Sans Pro"),
                "ui-sans-serif",
                "system-ui",
                "sans-serif",
            ],
            radius_size=gr.themes.sizes.radius_sm,
        )
    ) as block:
        # sanitize_html=False so the raw HTML above renders as markup.
        gr.Markdown(markdown, sanitize_html=False)
        state = gr.State(value=initial_state)
        with gr.Row():
            provider_select = gr.Dropdown(
                ["Anthropic", "OpenAI", "Google"],
                value="Anthropic",
                label="Provider",
            )
            model_select = gr.Dropdown(
                models.MODEL_NAMES_BY_PROVIDER["anthropic"],
                value="claude-instant-v1.2",
                label="Model",
            )
            scenario_select = gr.Dropdown(
                scenarios.SCENARIOS.keys(),
                value=initial_state.scenario_name,
                label="Scenario",
            )
            apikey = gr.Textbox(placeholder="sk-...", label="API Key")
        chatbot = gr.Chatbot(initial_state.get_history(), show_label=False)
        textbox = gr.Textbox(placeholder=PLACEHOLDER, label=initial_state.get_info())
        reset_button = gr.Button("Reset Scenario")


        # Submit pipeline: append user message -> stream assistant reply ->
        # evaluate the turn and log it. The .then() chain runs sequentially.
        textbox.submit(
            send_user_message, [state, textbox], [state, chatbot, textbox], queue=True
        ).then(
            send_assistant_message,
            [state, apikey],
            [state, chatbot],
            queue=True,
        ).then(
            evaluate_and_log, state, [state, textbox], queue=True
        )

        # Changing provider/model/scenario always resets the conversation.
        provider_select.change(
            change_provider,
            [state, provider_select],
            [state, model_select],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        model_select.change(
            change_model,
            [state, model_select],
            [state],
            queue=False,
        ).then(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )

        scenario_select.change(
            change_scenario,
            [state, scenario_select],
            [state, textbox],
            queue=False,
        ).then(reset_scenario, state, [state, chatbot, textbox], queue=False)

        reset_button.click(
            reset_scenario, state, [state, chatbot, textbox], queue=False
        )
        # Fresh page load also starts from a clean scenario.
        block.load(reset_scenario, state, [state, chatbot, textbox], queue=False)


    return block
|
|
|
|
def main(args):
    """Entry point: build initial state, configure optional MongoDB logging,
    and launch the Gradio app.

    Args:
        args: argparse.Namespace with `port` and `hf_proxy`.
    """
    load_dotenv()


    initial_state = State(
        scenario_name="Encryption",
        provider_name="anthropic",
        model_name="claude-instant-v1.2",
    )
    # __post_init__ already populates these; re-assign the lists directly.
    # (Previously trailing commas wrapped each list in a 1-tuple, which broke
    # State.get_history()'s slicing/pairing of redacted_messages.)
    initial_state.messages = initial_state.get_initial_messages()
    initial_state.redacted_messages = initial_state.get_initial_messages(redacted=True)


    # Logging is optional: evaluate_and_log skips inserts when MONGO_DB is
    # None, so missing credentials disable logging instead of crashing.
    global MONGO_DB
    username = os.environ.get("MONGO_USERNAME")
    password = os.environ.get("MONGO_PASSWORD")
    host = os.environ.get("MONGO_HOST")
    if username and password and host:
        mongo_uri = MONGO_URI.format(username=username, password=password, host=host)
        client = MongoClient(mongo_uri)
        MONGO_DB = client["messages"]["v1.0"]
    else:
        print("MongoDB credentials not set; message logging disabled.")


    block = make_block(initial_state)
    block.queue(concurrency_count=2)
    block.launch(
        server_port=args.port,
        share=args.hf_proxy,
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse CLI flags and launch the app.
    main(parse_args())
|
|