| """ |
| Utility classes and functions for the GuardBench Leaderboard display. |
| """ |
|
|
| from dataclasses import dataclass, field, fields |
| from enum import Enum, auto |
| from typing import List, Optional |
|
|
|
|
class ModelType(Enum):
    """Model type categories used by the leaderboard."""

    Unknown = auto()
    OpenSource = auto()
    ClosedSource = auto()
    API = auto()

    def to_str(self, separator: str = " ") -> str:
        """Return the display label for this member.

        Args:
            separator: Text inserted between the two words of the compound
                labels ("Open"/"Closed" followed by "Source"). It has no
                effect on the ``Unknown`` and ``API`` labels.

        Returns:
            "Unknown", "Open<separator>Source", "Closed<separator>Source"
            or "API".
        """
        # Only the two "... Source" members need the separator.
        compound_prefix = {
            ModelType.OpenSource: "Open",
            ModelType.ClosedSource: "Closed",
        }.get(self)
        if compound_prefix is not None:
            return f"{compound_prefix}{separator}Source"
        return "API" if self is ModelType.API else "Unknown"
|
|
class GuardModelType(str, Enum):
    """Guard model types for the leaderboard.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value (e.g. ``"llama_guard"``).
    """

    LLAMA_GUARD = "llama_guard"
    PROMPT_GUARD_CLF = "prompt_guard_clf"
    ATLA_SELENE = "atla_selene"
    GEMMA_SHIELD = "gemma_shield"
    LLM_REGEXP = "llm_regexp"
    LLM_SO = "llm_so"

    def __str__(self) -> str:
        """Return the member name (e.g. ``"LLAMA_GUARD"``), not its value."""
        member_name = self.name
        return member_name
|
|
|
|
|
|
class Precision(Enum):
    """Numeric precision options for a model."""

    Unknown = auto()
    float16 = auto()
    bfloat16 = auto()
    float32 = auto()
    int8 = auto()
    int4 = auto()

    def __str__(self) -> str:
        """Return the member name, e.g. ``"bfloat16"``."""
        member_name = self.name
        return member_name
|
|
|
|
class WeightType(Enum):
    """Kinds of model weights."""

    Original = auto()
    Delta = auto()
    Adapter = auto()

    def __str__(self) -> str:
        """Return the member name, e.g. ``"Adapter"``."""
        member_name = self.name
        return member_name
|
|
|
|
@dataclass
class ColumnInfo:
    """Description of a single leaderboard column."""

    # Internal column identifier; the module-level column lists are built
    # from this value.
    name: str
    # Label paired with ``name`` when column choices are offered.
    display_name: str
    # Column kind; columns with "number" are collected into METRIC_COLS.
    type: str = "text"
    # Columns with this flag set are collected into HIDDEN_COLS.
    hidden: bool = False
    # Columns with this flag set are collected into NEVER_HIDDEN_COLS.
    never_hidden: bool = False
    # Columns with this flag set are part of the default visible set.
    displayed_by_default: bool = True
|
|
|
|
def _column(name, display_name, **options):
    """Return a dataclass field that defaults to a fresh ``ColumnInfo``.

    A ``default_factory`` is used so that every ``GuardBenchColumn`` instance
    builds its own ``ColumnInfo`` objects instead of sharing one mutable
    default between instances.

    Args:
        name: Value for ``ColumnInfo.name``.
        display_name: Value for ``ColumnInfo.display_name``.
        **options: Remaining ``ColumnInfo`` keyword arguments (``type``,
            ``hidden``, ``never_hidden``, ``displayed_by_default``).

    Returns:
        The ``dataclasses.field(...)`` object to assign in the class body.
    """
    return field(
        default_factory=lambda: ColumnInfo(
            name=name, display_name=display_name, **options
        )
    )


def _metric_column(name, display_name, displayed_by_default=False):
    """Return a ``_column`` field for a numeric (``type="number"``) metric."""
    return _column(
        name,
        display_name,
        type="number",
        displayed_by_default=displayed_by_default,
    )


@dataclass
class GuardBenchColumn:
    """Columns for the GuardBench leaderboard.

    Each attribute holds the ``ColumnInfo`` for one column, and the attribute
    name equals that column's ``ColumnInfo.name``. The first group describes
    the model itself; the remaining four groups hold the same six numeric
    metrics for each test type (default/jailbreaked prompts and answers).
    Within each metric group only the ``*_f1`` column is displayed by default.
    """

    # --- Model metadata ---------------------------------------------------
    model_name: ColumnInfo = _column(
        "model_name", "Model", never_hidden=True, displayed_by_default=True
    )
    model_type: ColumnInfo = _column(
        "model_type", "Type", displayed_by_default=True
    )
    submission_date: ColumnInfo = _column(
        "submission_date", "Submission Date", displayed_by_default=False
    )
    version: ColumnInfo = _column(
        "version", "Version", displayed_by_default=False
    )
    guard_model_type: ColumnInfo = _column(
        "guard_model_type", "Guard Model Type", displayed_by_default=True
    )
    base_model: ColumnInfo = _column(
        "base_model", "Base Model", displayed_by_default=False
    )
    revision: ColumnInfo = _column(
        "revision", "Revision", displayed_by_default=False
    )
    precision: ColumnInfo = _column(
        "precision", "Precision", displayed_by_default=False
    )
    weight_type: ColumnInfo = _column(
        "weight_type", "Weight Type", displayed_by_default=False
    )

    # --- Default prompts --------------------------------------------------
    default_prompts_f1_binary: ColumnInfo = _metric_column(
        "default_prompts_f1_binary", "Default Prompts F1 Binary"
    )
    default_prompts_f1: ColumnInfo = _metric_column(
        "default_prompts_f1", "Default Prompts F1", displayed_by_default=True
    )
    default_prompts_recall_binary: ColumnInfo = _metric_column(
        "default_prompts_recall_binary", "Default Prompts Recall"
    )
    default_prompts_precision_binary: ColumnInfo = _metric_column(
        "default_prompts_precision_binary", "Default Prompts Precision"
    )
    default_prompts_error_ratio: ColumnInfo = _metric_column(
        "default_prompts_error_ratio", "Default Prompts Error Ratio"
    )
    default_prompts_avg_runtime_ms: ColumnInfo = _metric_column(
        "default_prompts_avg_runtime_ms", "Default Prompts Avg Runtime (ms)"
    )

    # --- Jailbreaked prompts ----------------------------------------------
    jailbreaked_prompts_f1_binary: ColumnInfo = _metric_column(
        "jailbreaked_prompts_f1_binary", "Jailbreaked Prompts F1 Binary"
    )
    jailbreaked_prompts_f1: ColumnInfo = _metric_column(
        "jailbreaked_prompts_f1",
        "Jailbreaked Prompts F1",
        displayed_by_default=True,
    )
    jailbreaked_prompts_recall_binary: ColumnInfo = _metric_column(
        "jailbreaked_prompts_recall_binary", "Jailbreaked Prompts Recall"
    )
    jailbreaked_prompts_precision_binary: ColumnInfo = _metric_column(
        "jailbreaked_prompts_precision_binary", "Jailbreaked Prompts Precision"
    )
    jailbreaked_prompts_error_ratio: ColumnInfo = _metric_column(
        "jailbreaked_prompts_error_ratio", "Jailbreaked Prompts Error Ratio"
    )
    jailbreaked_prompts_avg_runtime_ms: ColumnInfo = _metric_column(
        "jailbreaked_prompts_avg_runtime_ms",
        "Jailbreaked Prompts Avg Runtime (ms)",
    )

    # --- Default answers --------------------------------------------------
    default_answers_f1_binary: ColumnInfo = _metric_column(
        "default_answers_f1_binary", "Default Answers F1 Binary"
    )
    default_answers_f1: ColumnInfo = _metric_column(
        "default_answers_f1", "Default Answers F1", displayed_by_default=True
    )
    default_answers_recall_binary: ColumnInfo = _metric_column(
        "default_answers_recall_binary", "Default Answers Recall"
    )
    default_answers_precision_binary: ColumnInfo = _metric_column(
        "default_answers_precision_binary", "Default Answers Precision"
    )
    default_answers_error_ratio: ColumnInfo = _metric_column(
        "default_answers_error_ratio", "Default Answers Error Ratio"
    )
    default_answers_avg_runtime_ms: ColumnInfo = _metric_column(
        "default_answers_avg_runtime_ms", "Default Answers Avg Runtime (ms)"
    )

    # --- Jailbreaked answers ----------------------------------------------
    jailbreaked_answers_f1_binary: ColumnInfo = _metric_column(
        "jailbreaked_answers_f1_binary", "Jailbreaked Answers F1 Binary"
    )
    jailbreaked_answers_f1: ColumnInfo = _metric_column(
        "jailbreaked_answers_f1",
        "Jailbreaked Answers F1",
        displayed_by_default=True,
    )
    jailbreaked_answers_recall_binary: ColumnInfo = _metric_column(
        "jailbreaked_answers_recall_binary", "Jailbreaked Answers Recall"
    )
    jailbreaked_answers_precision_binary: ColumnInfo = _metric_column(
        "jailbreaked_answers_precision_binary", "Jailbreaked Answers Precision"
    )
    jailbreaked_answers_error_ratio: ColumnInfo = _metric_column(
        "jailbreaked_answers_error_ratio", "Jailbreaked Answers Error Ratio"
    )
    jailbreaked_answers_avg_runtime_ms: ColumnInfo = _metric_column(
        "jailbreaked_answers_avg_runtime_ms",
        "Jailbreaked Answers Avg Runtime (ms)",
    )
|
|
|
|
| |
# Module-level instance holding the ColumnInfo for every leaderboard column.
GUARDBENCH_COLUMN = GuardBenchColumn()

# ColumnInfo objects in field-declaration order; the lists below filter these.
_COLUMN_INFOS = [
    getattr(GUARDBENCH_COLUMN, column_field.name)
    for column_field in fields(GUARDBENCH_COLUMN)
]

# Dataclass field names of every column.
COLS = [column_field.name for column_field in fields(GUARDBENCH_COLUMN)]
# Names of columns flagged ``displayed_by_default``.
DISPLAY_COLS = [info.name for info in _COLUMN_INFOS if info.displayed_by_default]
# Names of columns whose ``type`` is "number".
METRIC_COLS = [info.name for info in _COLUMN_INFOS if info.type == "number"]
# Names of columns flagged ``hidden``.
HIDDEN_COLS = [info.name for info in _COLUMN_INFOS if info.hidden]
# Names of columns flagged ``never_hidden``.
NEVER_HIDDEN_COLS = [info.name for info in _COLUMN_INFOS if info.never_hidden]
|
|
| |
# Category labels, in their fixed order.
CATEGORIES = [
    "Criminal, Violent, and Terrorist Activity",
    "Manipulation, Deception, and Misinformation",
    "Creative Content Involving Illicit Themes",
    "Sexual Content and Violence",
    "Political Corruption and Legal Evasion",
    "Labor Exploitation and Human Trafficking",
    "Environmental and Industrial Harm",
    "Animal Cruelty and Exploitation",
    # The next label uses an en dash (U+2013), not a hyphen; keep it as is.
    "Self–Harm and Suicidal Ideation",
    "Safe Prompts",
]
|
|
| |
# Test type identifiers; each one is the name prefix of a group of metric
# columns in GuardBenchColumn.
TEST_TYPES = [
    "default_prompts",
    "jailbreaked_prompts",
    "default_answers",
    "jailbreaked_answers",
]
|
|
| |
# Metric name suffixes; "<test_type>_<metric>" is the name of a metric column.
# NOTE(review): GuardBenchColumn also declares "<test_type>_f1" columns, but
# "f1" is not listed here - confirm whether that omission is intentional.
METRICS = [
    "f1_binary",
    "recall_binary",
    "precision_binary",
    "error_ratio",
    "avg_runtime_ms",
]
|
|
def get_all_column_choices():
    """
    Get the optional column choices for the multiselect dropdown.

    Columns that are displayed by default are deliberately left out, so the
    result contains only the columns that are not in the default visible set.

    Returns:
        List of (column_name, display_name) tuples, in field-declaration
        order, for every column that is not displayed by default.
    """
    # A set gives constant-time membership tests inside the comprehension.
    default_visible_columns = set(get_default_visible_columns())

    column_infos = (
        getattr(GUARDBENCH_COLUMN, column_field.name)
        for column_field in fields(GUARDBENCH_COLUMN)
    )
    return [
        (info.name, info.display_name)
        for info in column_infos
        if info.name not in default_visible_columns
    ]
|
|
def get_default_visible_columns():
    """
    Collect the names of the columns that are shown by default.

    The list is rebuilt from GUARDBENCH_COLUMN on every call.

    Returns:
        List of column names, in field-declaration order, whose
        ``displayed_by_default`` flag is set.
    """
    visible_names = []
    for column_field in fields(GUARDBENCH_COLUMN):
        info = getattr(GUARDBENCH_COLUMN, column_field.name)
        if info.displayed_by_default:
            visible_names.append(info.name)
    return visible_names
|
|