Rishi2455 committed on
Commit
1ef6c6d
·
verified ·
1 Parent(s): 72915e4

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +36 -0
  2. app.py +193 -0
  3. requirements.txt +5 -0
README.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Assistant End-Turn Detector
3
+ emoji: 🗣️
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 6.13.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # 🗣️ Assistant End-Turn Detector
13
+
14
+ This Hugging Face Space hosts a **Sequence Classification** model designed to detect when a user wants an AI assistant to **STOP** or **CONTINUE** speaking during a real-time conversation.
15
+
16
+ ## 🚀 Overview
17
+
18
+ - **Architecture:** `LlamaForSequenceClassification`
19
+ - **Base Model:** [SmolLM2-135M-Instruct](https://huggingface.co/HuggingFaceTB/SmolLM2-135M-Instruct)
20
+ - **Task:** Binary classification (0: CONTINUE, 1: STOP)
21
+ - **Use Case:** Real-time turn-taking and interruption handling for voice/chat bots.
22
+
23
+ ## 🛠️ Implementation
24
+
25
+ The model was fine-tuned on custom dialogue datasets where users either provide back-channeling (encouraging continuation) or interruptions (asking questions, changing topics).
26
+
27
+ ### How to use locally
28
+
29
+ 1. Install dependencies:
30
+ ```bash
31
+ pip install -r requirements.txt
32
+ ```
33
+ 2. Run the application:
34
+ ```bash
35
+ python app.py
36
+ ```
app.py ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import os
5
+ import time
6
+
7
# --- Configuration ---
# Hugging Face Hub repository ID of the fine-tuned classifier. Update this
# with your actual repository ID; it is used when no local copy of the model
# files is found.
MODEL_ID = "Rishi2455/Assistant-End-Turn"
# Map of predicted class IDs to the human-readable labels shown in the UI.
LABEL_MAP = {0: "CONTINUE ✅", 1: "STOP 🛑"}
12
+
13
+ # --- Model Loading ---
14
def load_model():
    """Load the sequence-classification model and its tokenizer.

    The model source is resolved in priority order:

    1. the current directory, when ``./config.json`` exists (Space upload);
    2. ``./models/ETDv8``, when that path exists;
    3. the Hugging Face Hub repository named by ``MODEL_ID``.

    Returns:
        A 4-tuple ``(model, tokenizer, device, error)``. On success ``error``
        is ``None`` and the model is on ``device`` in eval mode. On any
        failure the tuple is ``(None, None, "cpu", <exception message>)``.
    """
    print(f"🚀 Loading model: {MODEL_ID}")
    try:
        if os.path.exists("./config.json"):
            path = "./"
        elif os.path.exists("./models/ETDv8"):
            path = "./models/ETDv8"
        else:
            path = MODEL_ID

        # Truncate from the left: when a dialogue exceeds the token budget,
        # the most recent user turn (the one being classified) must survive.
        tokenizer = AutoTokenizer.from_pretrained(path, truncation_side="left")

        # Determine device and torch dtype.
        if torch.cuda.is_available():
            device = "cuda"
            # Not every CUDA device supports bfloat16; fall back to float16.
            if torch.cuda.is_bf16_supported():
                dtype = torch.bfloat16
            else:
                dtype = torch.float16
        elif torch.backends.mps.is_available():
            device = "mps"
            dtype = torch.float16
        else:
            device = "cpu"
            dtype = torch.float32

        model = AutoModelForSequenceClassification.from_pretrained(
            path, torch_dtype=dtype
        ).to(device)
        model.eval()

        return model, tokenizer, device, None
    except Exception as e:
        # Top-level boundary: the caller renders the message in the UI, so
        # report the failure instead of raising, but also log it.
        print(f"Model loading failed: {e}")
        return None, None, "cpu", str(e)
48
+
49
# Module-level cache so the model is loaded lazily, on first use inside
# Gradio, instead of at import time.
model_data = {"model": None, "tokenizer": None, "device": "cpu", "error": None}


def get_model():
    """Return the cached ``(model, tokenizer, device, error)`` tuple.

    When no model is cached yet, ``load_model()`` is called first and its
    four results are stored in ``model_data``.
    """
    if model_data["model"] is None:
        model, tokenizer, device, error = load_model()
        model_data.update(
            model=model, tokenizer=tokenizer, device=device, error=error
        )
    return (
        model_data["model"],
        model_data["tokenizer"],
        model_data["device"],
        model_data["error"],
    )
56
+
57
+ # --- Inference Logic ---
58
def _message_text(content):
    """Flatten a chat message's ``content`` into plain text.

    Strings are returned unchanged and ``None`` becomes ``""``. A list or
    tuple is treated as a sequence of parts: string parts are kept and dict
    parts contribute their ``"text"`` value when it is a string; other parts
    are skipped. Any other value is converted with ``str``.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, (list, tuple)):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "".join(parts)
    return str(content)


def detect_turn_end(history):
    """Classify the dialogue and render the prediction as an HTML card.

    Args:
        history: List of message dicts with ``"role"`` and ``"content"``
            keys. The last message must have role ``"user"``.

    Returns:
        An HTML string: a hint when the history is empty or does not end
        with a user message, an error banner when the model is unavailable,
        otherwise a card showing the label, confidence, latency and device.
    """
    import html  # local import: only needed to escape the error text

    if not history or history[-1]["role"] != "user":
        return "<div style='color: #64748b; text-align: center; padding: 20px;'>Last message should be from user</div>"

    model, tokenizer, device, error = get_model()
    if model is None:
        # The error text comes from an exception message; escape it so it
        # cannot be interpreted as markup.
        detail = html.escape(error) if error else "Unknown error"
        return f"<div style='color: #ef4444; padding: 10px; border: 1px solid #ef4444; border-radius: 5px;'><b>❌ Model Error:</b> {detail}</div>"

    # 1. Prepare dialogue history. Build fresh role/content dicts so the
    #    caller's list is untouched and non-string content is flattened.
    dialogue = [
        {"role": m["role"], "content": _message_text(m.get("content"))}
        for m in history
    ]
    if not any(m["role"] == "system" for m in dialogue):
        dialogue.insert(0, {"role": "system", "content": "You are a helpful AI assistant named SmolLM, trained by Hugging Face"})

    # 2. Apply chat template.
    try:
        input_text = tokenizer.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=False)
    except Exception:
        # Manual fallback if the template is missing or cannot be rendered.
        input_text = "".join(
            f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in dialogue
        )

    # 3. Predict.
    tokens = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024).to(device)

    start_time = time.perf_counter()  # monotonic clock for durations
    with torch.no_grad():
        outputs = model(**tokens)
    latency = (time.perf_counter() - start_time) * 1000

    # Highest-confidence prediction for the single sequence in the batch;
    # softmax is computed in float32 regardless of the model dtype.
    probs = torch.softmax(outputs.logits[0].float(), dim=-1)
    pred_idx = torch.argmax(probs).item()
    confidence = probs[pred_idx].item()

    label = LABEL_MAP.get(pred_idx, "UNKNOWN")
    # Green styling for class 0, red for everything else.
    color = "#10b981" if pred_idx == 0 else "#ef4444"
    bg_color = "rgba(16, 185, 129, 0.15)" if pred_idx == 0 else "rgba(239, 68, 68, 0.15)"

    result_html = f"""
<div style="padding: 24px; border-radius: 12px; background-color: {bg_color}; border: 2px solid {color}; backdrop-filter: blur(8px);">
<div style="display: flex; justify-content: space-between; align-items: center;">
<h1 style="margin: 0; color: white; font-size: 2em; letter-spacing: 1px;">{label}</h1>
<div style="text-align: right;">
<p style="margin: 0; color: #94a3b8; font-size: 0.9em;">CONFIDENCE</p>
<b style="color: {color}; font-size: 1.4em;">{confidence:.2%}</b>
</div>
</div>
<div style="margin-top: 15px; padding-top: 15px; border-top: 1px solid rgba(255,255,255,0.1); display: flex; gap: 20px;">
<p style="margin: 0; color: #cbd5e1; font-size: 0.85em;">Latency: <b>{latency:.1f}ms</b></p>
<p style="margin: 0; color: #cbd5e1; font-size: 0.85em;">Device: <b>{device.upper()}</b></p>
</div>
</div>
"""
    return result_html
111
+
112
+ # --- UI Layout ---
113
def _placeholder_html(text):
    """Return the dashed placeholder card shown in the decision panel."""
    return (
        "<div style='height: 150px; display: flex; align-items: center; "
        "justify-content: center; border: 2px dashed #334155; "
        "border-radius: 12px; color: #64748b; text-align: center;'>"
        f"{text}</div>"
    )


def user_action(message, history):
    """Append a non-blank user message to the chat history.

    Args:
        message: Text from the input box.
        history: Current list of chat message dicts (may be ``None``).

    Returns:
        ``("", new_history)``: an empty string to clear the input box and a
        new list; the list passed in is never mutated.
    """
    history = list(history or [])
    if not message or not message.strip():
        return "", history
    history.append({"role": "user", "content": message})
    return "", history


def perform_inference(history):
    """Run turn-end detection on the chat history and return the HTML card."""
    return detect_turn_end(history)


# NOTE(review): README.md pins sdk_version 6.13.0 -- confirm that `theme=` on
# gr.Blocks and `type=` / `bubble_full_width=` on gr.Chatbot are still
# accepted by that Gradio release.
with gr.Blocks(
    theme=gr.themes.Default(
        primary_hue="indigo",
        font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui"],
    )
) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown("# 🤖 AI Turn-End Detector")
        gr.Markdown(
            "Predict if the AI assistant should **STOP** or **CONTINUE** "
            "speaking based on the latest user interaction.\n"
            "Built with SmolLM2-135M and Fine-tuned for real-time turn detection."
        )

        with gr.Row():
            # Left: conversation and input controls.
            with gr.Column(scale=2):
                chat = gr.Chatbot(
                    type="messages",
                    label="Dialogue Stream",
                    height=450,
                    bubble_full_width=False,
                    show_label=False,
                )
                with gr.Row():
                    txt = gr.Textbox(
                        label="User Input",
                        placeholder="Type a message or an interruption...",
                        scale=9,
                        container=False,
                    )
                    btn = gr.Button("🔮 Predict", variant="primary", scale=1)

                with gr.Row():
                    clear = gr.Button("🗑️ Clear Context")
                    undo = gr.Button("🔙 Undo Last")

            # Right: model decision and static details.
            with gr.Column(scale=1):
                gr.Markdown("### 🔍 Model Decision")
                status_box = gr.HTML(
                    _placeholder_html("Send a message to see the model's prediction")
                )

                with gr.Accordion("Technical Details", open=True):
                    gr.Markdown(
                        "- **Architecture:** Llama-based Sequence Classification\n"
                        "- **Base Model:** SmolLM2-135M-Instruct\n"
                        "- **Target:** Real-time Interruption Detection\n"
                        f"- **HF Repo:** `{MODEL_ID}`"
                    )

        # NOTE(review): each example row has two values but `inputs` lists a
        # single component; confirm whether the second value (a description)
        # is meant to be displayed.
        gr.Examples(
            examples=[
                ["Can you please...", "User stops mid-sentence (interruption)"],
                ["Yes, tell me more.", "Positive feedback (continue)"],
                ["Wait, I didn't get that part.", "Question (stop)"],
                ["Okay.", "Short affirmative (stop)"],
            ],
            inputs=[txt],
        )

    # Trigger chain: add the message to the chat, then classify the dialogue.
    txt.submit(user_action, [txt, chat], [txt, chat]).then(
        perform_inference, [chat], [status_box]
    )
    btn.click(user_action, [txt, chat], [txt, chat]).then(
        perform_inference, [chat], [status_box]
    )

    # Clearing resets both the chat and the decision panel.
    clear.click(
        lambda: ([], _placeholder_html("History Cleared")),
        None,
        [chat, status_box],
    )
    # Undo drops the last message, then re-runs inference on what remains.
    undo.click(lambda h: h[:-1] if h else [], [chat], [chat]).then(
        perform_inference, [chat], [status_box]
    )
181
+
182
# Custom premium styling: stylesheet text assigned to the Blocks app, followed
# by the script entry point.
_CUSTOM_CSS = """
body { background-color: #0f172a !important; color: #f8fafc !important; }
#container { max-width: 1100px; margin: auto; padding: 20px; }
.gr-chatbot { border-radius: 12px !important; border: 1px solid #1e293b !important; background-color: #1e293b !important; }
.message-row { transition: all 0.2s ease-in-out; }
.message-row:hover { transform: scale(1.01); }
footer { display: none !important; }
"""
demo.css = _CUSTOM_CSS

# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ gradio
4
+ accelerate
5
+ numpy