Spaces:
Sleeping
Sleeping
add files
Browse files- Makefile +8 -0
- app/.DS_Store +0 -0
- app/__init__.py +0 -0
- app/config/__pycache__/model_params.cpython-310.pyc +0 -0
- app/config/model_params.py +6 -0
- app/main.py +201 -0
- app/utils/__init__.py +0 -0
- app/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- app/utils/__pycache__/api.cpython-310.pyc +0 -0
- app/utils/__pycache__/classification.cpython-310.pyc +0 -0
- app/utils/__pycache__/evaluation.cpython-310.pyc +0 -0
- app/utils/__pycache__/prompt.cpython-310.pyc +0 -0
- app/utils/__pycache__/tokens.cpython-310.pyc +0 -0
- app/utils/__pycache__/validation.cpython-310.pyc +0 -0
- app/utils/api.py +26 -0
- app/utils/classification.py +26 -0
- app/utils/evaluation.py +21 -0
- app/utils/prompt.py +54 -0
- app/utils/tokens.py +15 -0
- app/utils/validation.py +18 -0
- requirements.txt +8 -0
- tests/__init__.py +0 -0
- tests/__pycache__/__init__.cpython-310.pyc +0 -0
- tests/__pycache__/test_api.cpython-310-pytest-8.3.4.pyc +0 -0
- tests/__pycache__/test_evaluation.cpython-310-pytest-8.3.4.pyc +0 -0
- tests/__pycache__/test_prompt.cpython-310-pytest-8.3.4.pyc +0 -0
- tests/__pycache__/test_validation.cpython-310-pytest-8.3.4.pyc +0 -0
- tests/test_api.py +18 -0
- tests/test_evaluation.py +14 -0
- tests/test_prompt.py +23 -0
- tests/test_validation.py +17 -0
Makefile
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.PHONY: setup run test

# Install the Python dependencies listed in requirements.txt.
setup:
	pip install -r requirements.txt

# Launch the Streamlit UI. The entry point added by this commit is app/main.py
# (there is no top-level app.py).
run:
	streamlit run app/main.py

# Run the test suite with app/ on the import path so `utils.*` resolves.
test:
	PYTHONPATH=./app pytest
|
app/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app/__init__.py
ADDED
|
File without changes
|
app/config/__pycache__/model_params.cpython-310.pyc
ADDED
|
Binary file (303 Bytes). View file
|
|
|
app/config/model_params.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default request settings for the classifier: the preselected model, the
# completion token cap, the sampling temperature, and the list of models
# offered in the UI.
DEFAULT_PARAMS = dict(
    model="gpt-4o-mini-2024-07-18",
    max_tokens=60,
    temperature=0.0,
    # Structured-output-compatible models
    available_models=[
        "gpt-4o-mini-2024-07-18",
        "gpt-4o-2024-08-06",
    ],
)
|
app/main.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from utils.prompt import generate_prompts
|
| 4 |
+
from utils.classification import apply_classification
|
| 5 |
+
from utils.validation import generate_classification_model
|
| 6 |
+
from utils.api import get_openai_client
|
| 7 |
+
from utils.tokens import estimate_token_count
|
| 8 |
+
from config.model_params import DEFAULT_PARAMS
|
| 9 |
+
|
| 10 |
+
st.set_page_config(layout="wide")

# Streamlit App Title
st.title("LLM-based Classifier")

# Upload Dataset
uploaded_file = st.sidebar.file_uploader("Upload a CSV file", type=["csv"])
if uploaded_file:
    df = pd.read_csv(uploaded_file)
    st.write("### Data Preview", df.head())

    # Select Target Column
    label_column = st.selectbox("Select target column (if available):", df.columns.tolist())

    # Exclude Target Column from Feature Selection
    if label_column:  # Ensure the label column is defined
        filtered_columns = [col for col in df.columns if col != label_column]
    else:
        filtered_columns = df.columns.tolist()

    # Feature Selection
    features = st.multiselect("Select features:", filtered_columns, default=filtered_columns)

    # Validate Features
    if label_column in features:
        st.error(f"Target column '{label_column}' cannot be included in features. Please remove it.")
        st.stop()

    # Specify Prediction Column Name
    prediction_column = st.text_input(
        "Enter the name of the column to store predictions:", "Predicted Label"
    )

    # Writing predictions into an existing column (e.g. the target itself) would
    # overwrite the data and make the evaluation compare predictions with
    # themselves, so reject empty or colliding names up front.
    if not prediction_column or prediction_column in df.columns:
        st.error(
            f"Prediction column name '{prediction_column}' is empty or already exists in the dataset. "
            "Please choose a different name."
        )
        st.stop()

    # Define Labels and Descriptions
    st.write(f"### Describe the values {prediction_column} can take")
    num_labels = st.number_input("Number of unique labels:", min_value=2, step=1)

    # Create columns for labels and descriptions
    col1, col2 = st.columns(2)

    label_descriptions = {}
    for i in range(int(num_labels)):
        with col1:
            label = st.text_input(f"Label {i+1} name:", key=f"label_name_{i}")
        with col2:
            description = st.text_input(f"Label {i+1} description:", key=f"label_desc_{i}")
        # Blank names would all collapse onto the "" key and leak into the prompt
        # and the validation model as a bogus label, so skip them.
        label = label.strip()
        if label:
            label_descriptions[label] = description

    # Compare user-provided labels with unique target values
    if label_column:
        # Compare as strings: user labels are always text, while the CSV column
        # may hold ints/floats. Evaluation also stringifies both sides.
        unique_target_values = set(df[label_column].dropna().astype(str).unique())
        n_unique_target_values = len(unique_target_values)

        if n_unique_target_values > 20:
            st.warning(
                f"The selected column '{label_column}' has {n_unique_target_values} unique values, "
                f"which may not be ideal as a target for classification."
            )
            proceed = st.checkbox(
                f"I understand and still want to use '{label_column}' as the target column."
            )
            if not proceed:
                st.stop()

        # Get user-provided labels
        user_provided_labels = set(label_descriptions.keys())

        # Identify missing and extra labels
        missing_labels = unique_target_values - user_provided_labels
        extra_labels = user_provided_labels - unique_target_values

        # Display warnings for discrepancies
        if missing_labels:
            st.warning(
                f"The following values in the target column are not accounted for in the labels: {', '.join(map(str, missing_labels))}."
            )
        if extra_labels:
            st.warning(
                f"The following user-provided labels do not match any values in the target column: {', '.join(map(str, extra_labels))}."
            )

    # Few-Shot Prompting
    use_few_shot = st.checkbox("Use few-shot prompting with examples from the target column", value=False)

    if use_few_shot and label_column:
        st.info("Few-shot prompting is enabled. Examples will be selected from the dataset.")

        # Group by target column and select 2 examples per class
        few_shot_examples = (
            df.groupby(label_column, group_keys=False)
            .apply(lambda group: group.sample(min(2, len(group)), random_state=42))
        )

        # Show the few-shot examples for reference
        st.write("### Few-Shot Examples")
        st.write(few_shot_examples[[*features, label_column]])

        # Remove few-shot examples from the dataset
        remaining_data = df.drop(few_shot_examples.index)
    else:
        few_shot_examples = None
        remaining_data = df

    # Limit rows to 20 to control costs
    if len(remaining_data) > 20:
        st.warning("Only the first 20 rows of the remaining dataset will be sent to OpenAI to save costs.")

    # Copy so that adding the prediction column below writes to an independent
    # frame instead of a slice of the uploaded data (avoids SettingWithCopyWarning).
    limited_data = remaining_data.head(20).copy()

    # Prepare Few-Shot Examples for Prompting
    example_rows = []
    if use_few_shot and few_shot_examples is not None:
        for _, example in few_shot_examples.iterrows():
            example_rows.append({
                "features": {feature: example[feature] for feature in features},
                "label": example[label_column],
            })

    # API Key and Model Parameters
    openai_api_key = st.sidebar.text_input("Enter your OpenAI API Key:", type="password")
    model_params = {
        "model": st.selectbox(
            "Model:",
            DEFAULT_PARAMS["available_models"],
            index=DEFAULT_PARAMS["available_models"].index(DEFAULT_PARAMS["model"])
        ),
        "temperature": st.slider("Temperature:", min_value=0.0, max_value=1.0, value=DEFAULT_PARAMS["temperature"]),
        "max_tokens": DEFAULT_PARAMS["max_tokens"],
    }

    st.sidebar.write('**Model Config**')
    st.sidebar.json(DEFAULT_PARAMS)

    verbose = st.checkbox("Verbose", value=False)

    # Classification Button
    if st.button("Run Classification"):
        if not openai_api_key:
            st.error("Please provide a valid OpenAI API Key.")
        elif len(label_descriptions) < 2:
            # A classifier needs at least two distinct, named classes.
            st.error("Please provide at least two distinct, non-empty label names.")
        elif not features:
            # Without features the prompt would contain no input data.
            st.error("Please select at least one feature.")
        else:
            # Initialize OpenAI client
            client = get_openai_client(api_key=openai_api_key)

            # Dynamically create the Pydantic model for validation
            ClassificationOutput = generate_classification_model(list(label_descriptions.keys()))

            # Function to classify a single row
            def classify_row(row):
                # Generate system and user prompts
                system_prompt, user_prompt = generate_prompts(
                    row=row.to_dict(),
                    label_descriptions=label_descriptions,
                    features=features,
                    example_rows=example_rows,
                )

                # Show the prompts in an expander for transparency
                if verbose:
                    with st.expander(f"OpenAI Call Input for Row Index {row.name}"):
                        st.write("**System Prompt:**")
                        st.code(system_prompt)
                        st.write(f"Token Count (System Prompt): {estimate_token_count(system_prompt, model_params['model'])}")
                        st.write("**User Prompt:**")
                        st.code(user_prompt)
                        st.write(f"Token Count (User Prompt): {estimate_token_count(user_prompt, model_params['model'])}")

                # Make the OpenAI call and validate the output
                return apply_classification(
                    client=client,
                    model_params=model_params,
                    ClassificationOutput=ClassificationOutput,
                    system_prompt=system_prompt,
                    user_prompt=user_prompt,
                    verbose=verbose,
                    st=st
                )

            # Apply the classification to each row in the limited data
            limited_data[prediction_column] = limited_data.apply(classify_row, axis=1)

            # Display Predictions
            st.write(f"### Predictions ({prediction_column})", limited_data)

            # Evaluate if ground truth is available
            if label_column in limited_data.columns:
                from utils.evaluation import evaluate_predictions
                report = evaluate_predictions(limited_data[label_column], limited_data[prediction_column])
                st.write("### Evaluation Metrics")
                st.json(report)
            else:
                st.warning(f"Target column '{label_column}' or prediction column '{prediction_column}' is missing in the data.")
|
app/utils/__init__.py
ADDED
|
File without changes
|
app/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (150 Bytes). View file
|
|
|
app/utils/__pycache__/api.cpython-310.pyc
ADDED
|
Binary file (1.03 kB). View file
|
|
|
app/utils/__pycache__/classification.cpython-310.pyc
ADDED
|
Binary file (880 Bytes). View file
|
|
|
app/utils/__pycache__/evaluation.cpython-310.pyc
ADDED
|
Binary file (779 Bytes). View file
|
|
|
app/utils/__pycache__/prompt.cpython-310.pyc
ADDED
|
Binary file (2 kB). View file
|
|
|
app/utils/__pycache__/tokens.cpython-310.pyc
ADDED
|
Binary file (656 Bytes). View file
|
|
|
app/utils/__pycache__/validation.cpython-310.pyc
ADDED
|
Binary file (750 Bytes). View file
|
|
|
app/utils/api.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from openai import OpenAI
|
| 2 |
+
|
| 3 |
+
# Initialize OpenAI client
|
| 4 |
+
def get_openai_client(api_key):
    """Build an ``OpenAI`` client that authenticates with ``api_key``.

    Args:
        api_key (str): The API key passed to the ``OpenAI`` constructor.

    Returns:
        OpenAI: A new client instance.
    """
    openai_client = OpenAI(api_key=api_key)
    return openai_client
|
| 9 |
+
|
| 10 |
+
def classify_row_chat(prompt, client, model="gpt-3.5-turbo"):
    """
    Sends a classification prompt to the OpenAI Chat API and returns the predicted label.

    Args:
        prompt (str): The user prompt to classify data.
        client (OpenAI): The OpenAI client instance.
        model (str): The model to use for chat completion.

    Returns:
        str: The predicted label with surrounding whitespace removed, or an
        empty string when the response message carries no text content.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
    )
    content = response.choices[0].message.content
    # The message content can be None (e.g. a refusal); calling .strip() on it
    # would raise AttributeError, so map that case to an empty label.
    if content is None:
        return ""
    return content.strip()
|
app/utils/classification.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def apply_classification(client, model_params, ClassificationOutput, system_prompt, user_prompt, verbose=False, st=None):
    """
    Classify one observation with the chat API and validate the returned label.

    Args:
        client: OpenAI client exposing ``chat.completions.create``.
        model_params (dict): Must contain ``"model"``, ``"max_tokens"`` and ``"temperature"``.
        ClassificationOutput: Pydantic model class with a ``label`` field used to
            validate the raw model output.
        system_prompt (str): Content of the system message.
        user_prompt (str): Content of the user message.
        verbose (bool): When true (and ``st`` is given), report the raw
            prediction and any validation error through ``st``.
        st: Object providing ``info`` and ``error`` (the Streamlit module in the app).

    Returns:
        The validated label, or the string ``"INVALID"`` when the model output
        does not pass validation.
    """
    response = client.chat.completions.create(
        model=model_params["model"],
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        max_tokens=model_params["max_tokens"],
        temperature=model_params["temperature"],
    )
    content = response.choices[0].message.content
    # The message content can be None; treat it as an empty (hence invalid)
    # prediction instead of crashing on None.strip().
    raw_prediction = content.strip() if content is not None else ""

    # Log raw prediction for debugging
    if verbose and st:
        st.info(f"Raw Prediction: {raw_prediction}")

    # Pydantic v2 renamed parse_obj to model_validate (parse_obj is deprecated);
    # prefer the new name and fall back for v1 models.
    validate = getattr(ClassificationOutput, "model_validate", None) or ClassificationOutput.parse_obj

    # Validate and process the prediction. Pydantic's ValidationError is a
    # ValueError subclass, so this catches validation failures only.
    try:
        return validate({"label": raw_prediction}).label
    except ValueError as e:
        if verbose and st:
            st.error(f"Invalid prediction: {raw_prediction}. Error: {e}")
        return "INVALID"
|
| 26 |
+
|
app/utils/evaluation.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
from sklearn.metrics import classification_report
|
| 3 |
+
|
| 4 |
+
def evaluate_predictions(y_true, y_pred):
    """
    Evaluates predictions by converting labels to strings and generating a classification report.

    Args:
        y_true (pd.Series or list): True labels.
        y_pred (pd.Series or list): Predicted labels.

    Returns:
        dict: Classification report as a dictionary.
    """
    # Ensure both true and predicted labels are strings
    y_true_str = pd.Series(y_true).astype(str)
    y_pred_str = pd.Series(y_pred).astype(str)

    # Generate classification report. Labels that occur on only one side (for
    # example an "INVALID" prediction that never appears in the ground truth)
    # have an undefined precision or recall; zero_division=0 reports those as
    # 0.0 explicitly instead of emitting an UndefinedMetricWarning per label.
    report = classification_report(
        y_true_str,
        y_pred_str,
        output_dict=True,
        zero_division=0,
    )
    return report
|
app/utils/prompt.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def create_classification_prompt(row, label_descriptions, features, example_rows):
    """
    Generates system and user prompts for classification.

    The system prompt lists every label with its description and, when
    few-shot examples are given, appends an "Examples:" section with one
    Input/Label pair per example. The user prompt holds the feature values of
    ``row`` and ends with an open "Label:" for the model to complete.

    Args:
        row (dict): A single row of feature values.
        label_descriptions (dict): Mapping of labels to their descriptions.
        features (list): List of features to include in the prompt.
        example_rows (list): Few-shot examples for the prompt; each is a dict
            with a "features" mapping and a "label". May be empty or None.

    Returns:
        tuple: (system_prompt, user_prompt)

    Raises:
        KeyError: If a name in ``features`` is missing from ``row`` or from an
            example's "features" mapping.
    """
    # Collect the prompt line by line and join once, instead of growing a
    # string with repeated +=.
    system_lines = [
        "You are a classifier. Assign one of the following labels based on the input data:"
    ]
    system_lines.extend(
        f"- {label}: {desc}" for label, desc in label_descriptions.items()
    )

    # Few-shot examples
    if example_rows:
        system_lines.append("")  # blank line before the examples section
        system_lines.append("Examples:")
        for example in example_rows:
            example_features = "; ".join(
                f"{feature}: {example['features'][feature]}" for feature in features
            )
            system_lines.append(f"Input: {example_features}")
            system_lines.append(f"Label: {example['label']}")

    # Every line of the system prompt, including the last, ends with a newline.
    system_prompt = "\n".join(system_lines) + "\n"

    # User prompt for the current row
    user_features = "; ".join(f"{feature}: {row[feature]}" for feature in features)
    user_prompt = f"Input: {user_features}\nLabel:"

    return system_prompt, user_prompt
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def generate_prompts(row, label_descriptions, features, example_rows):
    """
    Build the prompts for one dataset row by delegating to create_classification_prompt.

    Args:
        row (dict): Row of the dataset.
        label_descriptions (dict): Mapping of labels to their descriptions.
        features (list): List of features to include in the prompt.
        example_rows (list): Few-shot examples for the prompt.

    Returns:
        tuple: (system_prompt, user_prompt)
    """
    prompts = create_classification_prompt(row, label_descriptions, features, example_rows)
    return prompts
|
app/utils/tokens.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tiktoken
|
| 2 |
+
|
| 3 |
+
def estimate_token_count(prompt: str, model: str) -> int:
    """
    Estimate the token count for a given prompt and model.

    Args:
        prompt (str): The input prompt to tokenize.
        model (str): The name of the model to use for token encoding.

    Returns:
        int: The estimated token count.
    """
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # tiktoken raises KeyError for model names it has no mapping for (for
        # example a model newer than the installed tiktoken). The count is only
        # an estimate shown in the UI, so fall back to a general-purpose
        # encoding instead of failing the whole classification run.
        encoding = tiktoken.get_encoding("cl100k_base")
    return len(encoding.encode(prompt))
|
app/utils/validation.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, create_model
|
| 2 |
+
from typing import Literal, List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def generate_classification_model(labels: List[str]) -> "type[BaseModel]":
    """
    Dynamically generates a Pydantic model for classification based on user-provided labels.

    Args:
        labels (List[str]): List of valid label strings.

    Returns:
        type[BaseModel]: A dynamically generated Pydantic model class (not an
        instance) with a single required ``label`` field.

    Raises:
        ValueError: If ``labels`` is empty, since no value could ever validate
            against a model without allowed labels.
    """
    if not labels:
        raise ValueError("At least one label is required to build the classification model.")

    return create_model(
        "DynamicClassificationOutput",
        label=(Literal[tuple(labels)], ...),  # Enforce that 'label' matches one of the valid labels
    )
|
requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Runtime and test dependencies.
# openai, pandas, pydantic and scikit-learn are imported directly by the app
# (utils/api.py, main.py, utils/validation.py, utils/evaluation.py), so they
# are listed explicitly rather than relied on as transitive installs.
kagglehub
openai
pandas
pydantic
pytest
pytest-mock
scikit-learn
sentencepiece
sentence_transformers
streamlit
tiktoken
transformers
|
tests/__init__.py
ADDED
|
File without changes
|
tests/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (146 Bytes). View file
|
|
|
tests/__pycache__/test_api.cpython-310-pytest-8.3.4.pyc
ADDED
|
Binary file (1.18 kB). View file
|
|
|
tests/__pycache__/test_evaluation.cpython-310-pytest-8.3.4.pyc
ADDED
|
Binary file (1.21 kB). View file
|
|
|
tests/__pycache__/test_prompt.cpython-310-pytest-8.3.4.pyc
ADDED
|
Binary file (1.49 kB). View file
|
|
|
tests/__pycache__/test_validation.cpython-310-pytest-8.3.4.pyc
ADDED
|
Binary file (1.53 kB). View file
|
|
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from unittest.mock import Mock
|
| 2 |
+
from utils.api import classify_row_chat
|
| 3 |
+
|
| 4 |
+
def test_classify_row_chat():
    """classify_row_chat returns the content of the first choice's message."""
    # Fake chat-completion response carrying a single choice
    fake_message = Mock(content="Positive")
    fake_response = Mock(choices=[Mock(message=fake_message)])

    # Fake client whose create() call yields that response
    client_mock = Mock()
    client_mock.chat.completions.create.return_value = fake_response

    prediction = classify_row_chat(
        prompt="Classify the following observation: Age: 25, Weight: 70\nLabel:",
        client=client_mock,
        model="gpt-3.5-turbo",
    )

    assert prediction == "Positive", "The classification should return 'Positive'"
|
tests/test_evaluation.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from utils.evaluation import evaluate_predictions
|
| 2 |
+
|
| 3 |
+
def test_evaluate_predictions():
    """Accuracy is 1.0 for identical labels and below 1.0 once one prediction differs."""
    truth = ["Positive", "Negative", "Positive"]

    # Predictions identical to the ground truth
    perfect_report = evaluate_predictions(truth, ["Positive", "Negative", "Positive"])
    assert perfect_report["accuracy"] == 1.0, "Accuracy should be 100% for perfect predictions"

    # First prediction flipped
    flawed_report = evaluate_predictions(truth, ["Negative", "Negative", "Positive"])
    assert flawed_report["accuracy"] < 1.0, "Accuracy should be less than 100% for mismatched predictions"
|
tests/test_prompt.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
from utils.prompt import generate_prompts
|
| 3 |
+
|
| 4 |
+
def test_generate_prompts():
    """Few-shot examples land in the system prompt; the user prompt ends with an open label."""
    feature_names = ["Age", "Weight", "Location"]
    descriptions = {
        "Positive": "The sentiment is positive.",
        "Negative": "The sentiment is negative.",
    }
    few_shot = [
        {"features": {"Age": 34, "Weight": 70, "Location": "Urban"}, "label": "Positive"},
        {"features": {"Age": 25, "Weight": 60, "Location": "Rural"}, "label": "Negative"},
    ]
    current_row = {"Age": 30, "Weight": 65, "Location": "Suburban"}

    system_prompt, user_prompt = generate_prompts(
        row=current_row,
        label_descriptions=descriptions,
        features=feature_names,
        example_rows=few_shot,
    )

    assert "Age: 34; Weight: 70; Location: Urban" in system_prompt
    assert "Label: Positive" in system_prompt
    assert "Label:" in user_prompt
|
| 23 |
+
|
tests/test_validation.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import ValidationError
|
| 2 |
+
from utils.validation import generate_classification_model
|
| 3 |
+
|
| 4 |
+
def test_classification_output_validation():
    """The generated model accepts a known label and rejects an unknown one."""
    # Imported locally because this test module's top-level imports do not
    # include pytest.
    import pytest

    # Dynamically generate classification model
    ClassificationOutput = generate_classification_model(["Positive", "Negative"])

    # Test valid input
    valid_output = ClassificationOutput(label="Positive")
    assert valid_output.label == "Positive", "Label should be 'Positive'"

    # Test invalid input. pytest.raises fails the test when no ValidationError
    # is raised; a bare try/except would pass silently in that case.
    with pytest.raises(ValidationError) as exc_info:
        ClassificationOutput(label="InvalidLabel")
    error_message = str(exc_info.value)
    assert "Input should be 'Positive' or 'Negative'" in error_message, "Should raise validation error with correct message"
|