| """ |
| Cognitive Proxy - Brain-Steered Language Model |
| Hugging Face Spaces deployment |
| Author: Sandro Andric |
| """ |
|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import pickle |
| import os |
| from pathlib import Path |
| from sklearn.decomposition import PCA |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import plotly.graph_objects as go |
| import plotly.express as px |
| import spaces |
|
|
| |
| import os |
| from pathlib import Path |
|
|
| |
# Directory containing this script. ``__file__`` is looked up through
# ``globals()`` because a bare reference raises NameError in contexts where it
# is not defined (interactive sessions, some notebook runners); in that case
# the current working directory is used instead.
_THIS_FILE = globals().get("__file__")
SCRIPT_DIR = Path(_THIS_FILE).parent if _THIS_FILE else Path.cwd()

_ATLAS_FILENAME = "final_atlas_256_vocab.pkl"
_ADAPTER_FILENAME = "tinyllama_adapter_direct.pt"

# Look for the atlas first in ``<script dir>/results`` and then next to the
# script itself; the adapter checkpoint is expected in the same directory as
# the atlas.
for _candidate_dir in (SCRIPT_DIR / "results", SCRIPT_DIR):
    if (_candidate_dir / _ATLAS_FILENAME).exists():
        ATLAS_PATH = str(_candidate_dir / _ATLAS_FILENAME)
        ADAPTER_PATH = str(_candidate_dir / _ADAPTER_FILENAME)
        break
else:
    # Nothing found: fall back to paths relative to the working directory.
    ATLAS_PATH = f"results/{_ATLAS_FILENAME}"
    ADAPTER_PATH = f"results/{_ADAPTER_FILENAME}"

print(f"Atlas path: {ATLAS_PATH}")
print(f"Adapter path: {ADAPTER_PATH}")

# Hugging Face Hub identifier of the language model that is steered.
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
|
| |
class TinyLlamaAdapterDirect(nn.Module):
    """Feed-forward network mapping a hidden-state vector to a flat output vector.

    The network is three ``Linear -> LayerNorm -> GELU`` stages (the first two
    followed by ``Dropout(0.1)``) and a final ``Linear`` projection to
    ``output_dim``. Layers live in ``self.net`` in a fixed order, so parameter
    names are ``net.<index>.weight`` / ``net.<index>.bias``.
    """

    def __init__(self, input_dim=2048, hidden_dim=1024, output_dim=65536):
        super().__init__()
        bottleneck_dim = hidden_dim // 2

        # (in_features, out_features, followed_by_dropout) for each stage.
        stage_specs = (
            (input_dim, hidden_dim, True),
            (hidden_dim, hidden_dim, True),
            (hidden_dim, bottleneck_dim, False),
        )

        layers = []
        for in_features, out_features, with_dropout in stage_specs:
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LayerNorm(out_features))
            layers.append(nn.GELU())
            if with_dropout:
                layers.append(nn.Dropout(0.1))
        layers.append(nn.Linear(bottleneck_dim, output_dim))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the projection."""
        projected = self.net(x)
        return projected
|
|
| |
# Lazily-built singleton holding the model, tokenizer, adapter and steering
# tensors; populated on the first call to ``load_system``.
system = None


def _fallback_atlas():
    """Return a two-word random atlas so the demo can start without real data."""
    return {'word1': np.random.randn(256, 256), 'word2': np.random.randn(256, 256)}


def _load_atlas(path):
    """Read the atlas pickle at ``path`` and return its payload.

    Returns an empty dict when the file is missing or cannot be unpickled.
    When the payload is a dict containing a ``'means'`` key, that entry is
    returned; otherwise the payload itself is returned unchanged.
    """
    if not os.path.exists(path):
        print(f"Atlas file not found at {path}")
        return {}

    print(f"Loading atlas from {path}")
    try:
        # SECURITY: unpickling can execute arbitrary code. Only load atlas
        # files that ship with this application; never a user-supplied file.
        with open(path, 'rb') as f:
            data = pickle.load(f)
    except (OSError, EOFError, pickle.UnpicklingError,
            AttributeError, ImportError, IndexError) as exc:
        print(f"Warning: could not read atlas at {path}: {exc}")
        return {}

    if not isinstance(data, dict):
        print(f"Atlas is not a dict, type: {type(data)}")
        return data

    print(f"Atlas data keys: {list(data.keys())[:5]}")
    if 'means' in data:
        atlas = data['means']
        print(f"Using 'means' key, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items")
    else:
        atlas = data
        print(f"Using data directly, got {len(atlas) if isinstance(atlas, dict) else 'not a dict'} items")
    return atlas


def _build_plv_matrix(atlas, expected_dim):
    """Stack the atlas entries into a ``(n_words, expected_dim)`` matrix.

    Every entry is flattened, so both 2-D matrices and already-flat vectors
    are accepted. If the entries are ragged, fewer than two rows result, or
    the row width differs from ``expected_dim`` (the adapter's output size),
    a random ``(10, expected_dim)`` matrix is returned instead so that the
    steering axis always matches the adapter output.
    """
    try:
        plv_matrix = np.array([np.array(atlas[w]).ravel() for w in atlas])
    except ValueError:
        # Ragged entries cannot be stacked into a rectangular array.
        plv_matrix = np.empty(0)

    if plv_matrix.ndim != 2 or plv_matrix.shape[0] < 2 or plv_matrix.shape[1] != expected_dim:
        print(f"Warning: Invalid PLV matrix shape {plv_matrix.shape}, using fallback")
        plv_matrix = np.random.randn(10, expected_dim)
    return plv_matrix


def load_system():
    """Build (once) and return the shared inference state.

    Loads the tokenizer and causal LM identified by ``MODEL_ID``, the adapter
    (with weights from ``ADAPTER_PATH`` when that file exists) and the atlas
    at ``ATLAS_PATH``. A PCA is fitted on the atlas rows; its first component
    (unit-normalised) is the steering axis and the row mean is the global mean.

    Returns:
        dict with keys ``'model'``, ``'tokenizer'``, ``'adapter'``, ``'axis'``,
        ``'global_mean'`` and ``'device'``. The same dict is returned on every
        subsequent call.
    """
    global system
    if system is not None:
        return system

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.pad_token = tokenizer.eos_token

    # Half precision on GPU, full precision on CPU.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    try:
        # Try the ``dtype`` keyword first ...
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=dtype).to(device)
    except TypeError:
        # ... and retry with ``torch_dtype`` if that keyword is rejected.
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=dtype).to(device)
    model.eval()

    adapter = TinyLlamaAdapterDirect().to(device).to(dtype)
    if os.path.exists(ADAPTER_PATH):
        adapter.load_state_dict(torch.load(ADAPTER_PATH, map_location=device, weights_only=True))
    else:
        # Without trained weights the adapter is randomly initialised and the
        # steering direction is meaningless; say so instead of staying silent.
        print(f"Warning: Adapter weights not found at {ADAPTER_PATH}, using untrained adapter")
    adapter.eval()

    atlas = _load_atlas(ATLAS_PATH)

    # Type check first: ``not atlas`` raises for multi-element numpy arrays.
    if not isinstance(atlas, dict) or not atlas:
        print(f"Warning: Atlas is empty or invalid, using fallback")
        atlas = _fallback_atlas()

    words = list(atlas.keys())
    print(f"Loaded atlas with {len(words)} words")
    if len(words) < 2:
        print(f"Warning: Not enough words in atlas ({len(words)}), using fallback")
        atlas = _fallback_atlas()
        words = list(atlas.keys())

    # The axis is multiplied element-wise with the adapter output, so the
    # atlas rows must have exactly the adapter's output width.
    expected_dim = adapter.net[-1].out_features
    plv_matrix = _build_plv_matrix(atlas, expected_dim)

    pca = PCA(n_components=min(10, plv_matrix.shape[0] - 1))
    pca.fit(plv_matrix)
    pc1_axis = pca.components_[0]
    pc1_axis = pc1_axis / np.linalg.norm(pc1_axis)
    global_mean = plv_matrix.mean(axis=0)

    system = {
        'model': model,
        'tokenizer': tokenizer,
        'adapter': adapter,
        'axis': torch.tensor(pc1_axis, dtype=torch.float32).to(device),
        'global_mean': torch.tensor(global_mean, dtype=torch.float32).to(device),
        'device': device
    }
    return system
|
|
@spaces.GPU(duration=60)
def generate_variants(prompt, scenario, max_tokens):
    """Generate three continuations of ``prompt``: negative, zero and positive steering.

    Args:
        prompt: User prompt text.
        scenario: ``"Educational"`` and ``"Technical writing"`` wrap the prompt
            in the ``<|user|>`` / ``<|assistant|>`` markers and steer with
            strength 5.0; any other value uses the raw prompt and strength 3.0.
        max_tokens: Maximum number of new tokens per variant (cast to ``int``).

    Returns:
        Tuple ``(negative_alpha_text, baseline_text, positive_alpha_text)``.
    """
    state = load_system()
    model = state['model']
    tokenizer = state['tokenizer']
    adapter = state['adapter']

    if scenario in ("Educational", "Technical writing"):
        prompt_formatted = f"<|user|>\n{prompt}\n<|assistant|>\n"
        alpha_strength = 5.0
    else:
        prompt_formatted = prompt
        alpha_strength = 3.0

    # UI / API callers may deliver the slider value as a float.
    max_tokens = int(max_tokens)

    # Loop-invariant work: adapter dtype, axis cast and prompt tokenisation.
    adapter_dtype = next(adapter.parameters()).dtype
    axis = state['axis'].to(adapter_dtype)
    prompt_ids = tokenizer(prompt_formatted, return_tensors='pt').to(state['device']).input_ids

    outputs = []
    for alpha in (-alpha_strength, 0, alpha_strength):
        generated_ids = prompt_ids.clone()

        for _ in range(max_tokens):
            # The hidden state is detached before steering, so the language
            # model forward pass never needs an autograd graph.
            with torch.no_grad():
                outputs_model = model(generated_ids, output_hidden_states=True)
            hidden = outputs_model.hidden_states[-1][:, -1, :].to(adapter_dtype)

            if alpha != 0:
                hidden = hidden.detach().requires_grad_(True)
                plv_pred = adapter(hidden)
                score = torch.sum(plv_pred * axis)
                grad = torch.autograd.grad(score, hidden)[0]
                # Normalise in float32: in float16 the 1e-8 epsilon rounds to
                # zero, so a vanishing gradient would produce NaNs.
                grad = grad.float()
                grad = grad / (grad.norm() + 1e-8)
                hidden = hidden.detach() + (alpha * grad).to(adapter_dtype)

            with torch.no_grad():
                # NOTE(review): in Hugging Face Llama models hidden_states[-1]
                # appears to already include the final norm, so the norm may be
                # applied twice here. Left unchanged because the adapter and
                # alpha strengths were presumably tuned with it - confirm.
                logits = model.lm_head(model.model.norm(hidden))
                probs = torch.softmax(logits / 0.8, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated_ids = torch.cat([generated_ids, next_token], dim=-1)
                if next_token.item() == tokenizer.eos_token_id:
                    break

        text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        # Keep only the reply when the chat markers survive decoding.
        if "<|assistant|>" in text:
            text = text.split("<|assistant|>")[-1].strip()
        outputs.append(text)

    return outputs[0], outputs[1], outputs[2]
|
|
@spaces.GPU(duration=30)
def analyze_text(text):
    """Project ``text`` onto the steering axis and build the visualisations.

    The last-token hidden state of the language model is passed through the
    adapter; the prediction is centred with the global mean and dotted with
    the axis to give a scalar score.

    Args:
        text: Text to analyse.

    Returns:
        Tuple ``(gauge_figure, interpretation_markdown, score, plv_figure)``
        where ``score`` is a Python float and both figures are Plotly figures.
    """
    state = load_system()
    adapter = state['adapter']

    with torch.no_grad():
        inputs = state['tokenizer'](text, return_tensors='pt').to(state['device'])
        out = state['model'](**inputs, output_hidden_states=True)
        last_hidden = out.hidden_states[-1][0, -1, :]

        adapter_dtype = next(adapter.parameters()).dtype
        last_hidden = last_hidden.to(adapter_dtype)
        plv_pred = adapter(last_hidden.unsqueeze(0))
        # The axis and global mean are stored in float32. Upcast the (possibly
        # float16) prediction instead of downcasting them, so the reduction
        # over all outputs keeps full precision.
        plv_flat = plv_pred[0].float()
        plv_centered = plv_flat - state['global_mean']
        score = (plv_centered * state['axis']).sum().item()

    # Widen the gauge beyond +/-300 when the score would not fit.
    gauge_min = min(-300, score - 50)
    gauge_max = max(300, score + 50)

    fig = go.Figure(go.Indicator(
        mode="number+gauge",
        value=score,
        gauge={
            'shape': "angular",
            'axis': {'range': [gauge_min, gauge_max], 'tickwidth': 0.5, 'tickcolor': '#ccc'},
            'bar': {'color': "#333", 'thickness': 0.15},
            'bgcolor': "white",
            'borderwidth': 1,
            'bordercolor': "#e0e0e0",
            'steps': [
                {'range': [gauge_min, -5], 'color': "#e8f5e9"},
                {'range': [-5, 5], 'color': "#fafafa"},
                {'range': [5, gauge_max], 'color': "#fff3e0"}
            ],
        },
        number={'font': {'size': 36, 'color': '#000'}}
    ))

    fig.update_layout(
        height=300,
        width=400,
        margin={'l': 30, 'r': 30, 't': 50, 'b': 30},
        paper_bgcolor='white',
        font={'color': '#666'}
    )

    # Same +/-5 thresholds as the gauge colour bands above.
    if score > 5:
        interpretation = "**Syntactic dominance** \nText patterns match brain activity during grammatical processing"
    elif score < -5:
        interpretation = "**Semantic dominance** \nText patterns match brain activity during meaning comprehension"
    else:
        interpretation = "**Balanced** \nMixed patterns - both structure and meaning equally present"

    # Heatmap of the predicted pattern; float32 so Plotly never receives a
    # half-precision array.
    plv_np = plv_flat.cpu().numpy()
    plv_matrix = plv_np[:65536].reshape(256, 256)

    fig_plv = px.imshow(
        plv_matrix,
        color_continuous_scale='Viridis',
        aspect='auto'
    )
    fig_plv.update_layout(
        coloraxis_showscale=True,
        coloraxis=dict(
            colorbar=dict(
                thickness=10,
                len=0.7,
                title=dict(text="Synchrony", side="right"),
                tickfont=dict(size=10)
            )
        ),
        margin={'l': 0, 'r': 40, 't': 10, 'b': 0},
        height=300
    )
    fig_plv.update_xaxes(visible=False)
    fig_plv.update_yaxes(visible=False)

    return fig, interpretation, score, fig_plv
|
|
@spaces.GPU(duration=60)
def generate_steered(prompt, alpha, max_tokens):
    """Generate a continuation of ``prompt`` with a caller-chosen steering strength.

    Args:
        prompt: Prompt text, used verbatim (no chat markers are added).
        alpha: Signed steering strength; ``0`` disables steering.
        max_tokens: Maximum number of new tokens (cast to ``int``).

    Returns:
        The decoded prompt plus continuation, with special tokens removed.
    """
    state = load_system()
    model = state['model']
    tokenizer = state['tokenizer']
    adapter = state['adapter']

    # UI / API callers may deliver the slider value as a float.
    max_tokens = int(max_tokens)

    # Loop-invariant work hoisted out of the decoding loop.
    adapter_dtype = next(adapter.parameters()).dtype
    axis = state['axis'].to(adapter_dtype)

    inputs = tokenizer(prompt, return_tensors='pt').to(state['device'])
    generated_ids = inputs.input_ids.clone()

    for _ in range(max_tokens):
        # The hidden state is detached before steering, so the language model
        # forward pass never needs an autograd graph.
        with torch.no_grad():
            outputs_model = model(generated_ids, output_hidden_states=True)
        hidden = outputs_model.hidden_states[-1][:, -1, :].to(adapter_dtype)

        if alpha != 0:
            hidden = hidden.detach().requires_grad_(True)
            plv_pred = adapter(hidden)
            score = torch.sum(plv_pred * axis)
            grad = torch.autograd.grad(score, hidden)[0]
            # Normalise in float32: in float16 the 1e-8 epsilon rounds to zero,
            # so a vanishing gradient would produce NaNs.
            grad = grad.float()
            grad = grad / (grad.norm() + 1e-8)
            hidden = hidden.detach() + (alpha * grad).to(adapter_dtype)

        with torch.no_grad():
            # NOTE(review): in Hugging Face Llama models hidden_states[-1]
            # appears to already include the final norm, so the norm may be
            # applied twice here. Left unchanged because the adapter was
            # presumably tuned with it - confirm.
            logits = model.lm_head(model.model.norm(hidden))
            probs = torch.softmax(logits / 0.8, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
|
| |
# Stylesheet handed to ``gr.Blocks(css=...)``. That parameter takes raw CSS, so
# the rules must NOT be wrapped in <style> tags: the literal tag text is parsed
# as part of the first selector and silently invalidates the first rule.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600&display=swap');

/* Global font */
.gradio-container, .gradio-container * {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif !important;
}

/* Clean header */
.main-header {
    font-size: 14px;
    font-weight: 300;
    letter-spacing: 2px;
    text-transform: uppercase;
    color: #666;
    margin-bottom: 8px;
}

.main-title {
    font-size: 48px;
    font-weight: 300;
    line-height: 1.1;
    letter-spacing: -1px;
    margin-bottom: 16px;
}

.subtitle {
    font-size: 18px;
    font-weight: 300;
    color: #666;
    line-height: 1.6;
}

/* Clean tabs like Streamlit */
.tabs {
    border-bottom: 1px solid #e0e0e0 !important;
}

.tab-nav button {
    background: none !important;
    border: none !important;
    border-bottom: 2px solid transparent !important;
    color: #666 !important;
    font-weight: 400 !important;
    font-size: 14px !important;
    padding: 8px 16px !important;
    text-transform: none !important;
}

.tab-nav button.selected {
    color: #000 !important;
    border-bottom-color: #000 !important;
}

/* Minimal buttons */
button.primary {
    background: white !important;
    border: 1px solid #000 !important;
    color: #000 !important;
    font-weight: 400 !important;
    padding: 10px 20px !important;
    transition: all 0.2s !important;
}

button.primary:hover {
    background: #000 !important;
    color: white !important;
}

/* Clean textboxes */
textarea, input[type="text"] {
    border: 1px solid #e0e0e0 !important;
    border-radius: 0 !important;
    font-size: 14px !important;
}

/* Section titles */
.section-title {
    font-size: 11px;
    font-weight: 500;
    letter-spacing: 1.5px;
    text-transform: uppercase;
    color: #999;
    margin: 24px 0 16px 0;
}

/* Value labels */
.value-label {
    font-size: 12px;
    color: #999;
    margin-bottom: 4px;
}

/* Remove gradio branding */
footer { display: none !important; }
"""
|
|
| |
# Prompt pre-filled into the Compare tab for each scenario.
DEFAULT_PROMPTS = dict(
    [
        ("Technical writing", "Draft a short SMS to the customer informing them their payment has failed."),
        ("Educational", "Explain in 2 sentences what the butterfly effect is."),
        ("Free form", "Brainstorm creative uses of brain-steered language models in five bullet points."),
    ]
)

# Text fragments shared by every scenario's column labels.
_SEMANTIC_SIDE = "Semantic / Content (meaning-heavy, concrete)"
_SYNTACTIC_SIDE = "Syntactic / Function (structure-heavy, abstract)"


def _axis_text(left_note=None, right_note=None):
    """Build one scenario's labels/captions; notes are appended in brackets."""
    left = _SEMANTIC_SIDE if left_note is None else f"{_SEMANTIC_SIDE} [{left_note}]"
    right = _SYNTACTIC_SIDE if right_note is None else f"{_SYNTACTIC_SIDE} [{right_note}]"
    return {
        "left_label": left,
        "baseline_label": "Baseline",
        "right_label": right,
        "left_caption": "*Steered toward meaning (brain semantic side)*",
        "baseline_caption": "*No brain steering*",
        "right_caption": "*Steered toward structure (brain syntactic side)*",
    }


# Labels and captions for the three output columns, keyed by scenario name.
SCENARIO_AXIS_TEXT = {
    "Technical writing": _axis_text("empathetic/actionable tone", "formal/policy tone"),
    "Educational": _axis_text("analogy/concrete style", "definition/logical style"),
    "Free form": _axis_text(),
}
|
|
def _value_label_html(text):
    """Wrap ``text`` in the ``value-label`` div used above each output column."""
    return f'<div class="value-label">{text}</div>'


# Top-level Gradio UI: header, help accordion and three tabs
# (Compare / Inspect / Steer), followed by a footer.
with gr.Blocks(
    title="Cognitive Proxy",
    theme=gr.themes.Base(
        primary_hue="gray",
        neutral_hue="gray",
        text_size="md",
        spacing_size="lg",
        radius_size="none",
    ),
    css=custom_css
) as demo:

    # Page header
    gr.HTML("""
        <div>
            <div class="main-header">Neural Language Interface</div>
            <div class="main-title">Cognitive Proxy</div>
            <div class="subtitle">Steering language models through brain-derived coordinate spaces.<br>
            Using MEG phase-locking patterns from 21 subjects as control geometry.</div>
            <div style="color: #999; font-size: 13px; margin-top: 16px;">Sandro Andric</div>
            <div style="color: #999; font-size: 11px; margin-top: 8px;">Demo model: TinyLlama-1.1B-Chat</div>
            <div style="margin-top: 12px;"><a href="https://arxiv.org/abs/2512.19399" style="color: #666; font-size: 12px;">📄 Read our latest research on brain-LLM alignment</a></div>
        </div>
    """)

    # Collapsible introduction
    with gr.Accordion("How this works", open=False):
        gr.Markdown("""
        **What makes this special:** This AI is controlled by real human brain data.
        We recorded brain activity from 21 people listening to stories, discovered how their brains organize language,
        and now use those patterns to steer what the AI generates.

        **Try this:**
        1. Start with the **Compare** tab and choose **Educational**
        2. Click "Generate all variants" to see three versions side by side
        3. Notice how the left (concrete) version uses analogies while the right (abstract) uses logic
        4. The difference comes from steering along brain axes discovered from MEG recordings

        **The science:** Different brain regions activate for grammar vs meaning.
        We project the AI's internal states into this brain coordinate system and steer along the axis.
        """)

    with gr.Tabs():
        # Tab 1: three variants side by side
        with gr.TabItem("Compare"):
            gr.HTML('<div class="section-title">Comparative Analysis</div>')

            gr.Markdown("""
            See how brain steering affects AI output. Try **Educational** to see the difference between
            abstract explanations vs concrete analogies, or **Technical writing** to compare formal vs friendly tones.
            All controlled by brain patterns from 21 human subjects.
            """)

            with gr.Row():
                scenario = gr.Dropdown(
                    choices=["Technical writing", "Educational", "Free form"],
                    value="Technical writing",
                    label="Scenario",
                    container=False
                )

            prompt = gr.Textbox(
                value=DEFAULT_PROMPTS["Technical writing"],
                label="",
                placeholder="Enter your prompt...",
                lines=4
            )

            def update_prompt(selected):
                """Return the default prompt for the selected scenario."""
                return DEFAULT_PROMPTS.get(selected, DEFAULT_PROMPTS["Free form"])

            scenario.change(
                update_prompt,
                inputs=[scenario],
                outputs=[prompt]
            )

            with gr.Row():
                max_tokens = gr.Slider(20, 150, 80, label="Max tokens", container=False)
                generate_btn = gr.Button("Generate all variants", variant="primary")

            gr.HTML('<div style="margin-top: 24px;"></div>')

            with gr.Row():
                with gr.Column():
                    # Initial labels match the dropdown's default scenario.
                    axis_text = SCENARIO_AXIS_TEXT["Technical writing"]
                    left_label = gr.HTML(_value_label_html(axis_text["left_label"]))
                    output_semantic = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    left_caption = gr.Markdown(axis_text["left_caption"], elem_classes=["caption"])

                with gr.Column():
                    baseline_label = gr.HTML(_value_label_html(axis_text["baseline_label"]))
                    output_baseline = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    baseline_caption = gr.Markdown(axis_text["baseline_caption"], elem_classes=["caption"])

                with gr.Column():
                    right_label = gr.HTML(_value_label_html(axis_text["right_label"]))
                    output_syntactic = gr.Textbox(
                        label="",
                        lines=10,
                        interactive=False,
                        container=False
                    )
                    right_caption = gr.Markdown(axis_text["right_caption"], elem_classes=["caption"])

            def update_axis_labels(selected):
                """Return the six label/caption values for the selected scenario.

                The three labels feed ``gr.HTML`` components, so they are
                wrapped in the same ``value-label`` div as the initial values;
                returning bare text would drop that styling after the first
                scenario change.
                """
                data = SCENARIO_AXIS_TEXT.get(selected, SCENARIO_AXIS_TEXT["Free form"])
                return (
                    _value_label_html(data["left_label"]),
                    _value_label_html(data["baseline_label"]),
                    _value_label_html(data["right_label"]),
                    data["left_caption"],
                    data["baseline_caption"],
                    data["right_caption"],
                )

            scenario.change(
                update_axis_labels,
                inputs=[scenario],
                outputs=[left_label, baseline_label, right_label, left_caption, baseline_caption, right_caption],
            )

            generate_btn.click(
                generate_variants,
                inputs=[prompt, scenario, max_tokens],
                outputs=[output_semantic, output_baseline, output_syntactic]
            )

        # Tab 2: score a piece of text
        with gr.TabItem("Inspect"):
            gr.HTML('<div class="section-title">Brain Space Projection</div>')

            gr.Markdown("""
            Enter any text to see how it aligns with brain patterns. The meter shows whether your text
            activates brain regions associated with grammar/structure (positive) or meaning/content (negative).
            """)

            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        value="The scientist discovered",
                        label="",
                        placeholder="Enter text to analyze...",
                        lines=6
                    )
                    analyze_btn = gr.Button("Project", variant="primary")

                with gr.Column():
                    gauge_plot = gr.Plot(label="")
                    interpretation = gr.Markdown("")

                    with gr.Accordion("What the number means", open=False):
                        gr.Markdown("""
                        - **Negative values (green)** = semantic/meaning focus
                        - **Positive values (amber)** = syntactic/grammar focus
                        - **Larger magnitude** = stronger pattern
                        - **Range** typically -300 to +300
                        """)

            with gr.Accordion("View brain connectivity pattern", open=False):
                gr.Markdown("""
                Phase-Locking Value (PLV) shows how synchronized different brain regions are.
                Brighter colors = stronger synchronization between sensor pairs.
                Each pixel represents connectivity between two of 256 MEG sensors.
                """)
                plv_plot = gr.Plot(label="")

            def analyze_text_wrapper(text):
                """Call ``analyze_text`` and drop the raw score, which has no output widget."""
                fig, interp, _, fig_plv = analyze_text(text)
                return fig, interp, fig_plv

            analyze_btn.click(
                analyze_text_wrapper,
                inputs=[text_input],
                outputs=[gauge_plot, interpretation, plv_plot]
            )

        # Tab 3: free-form generation with a user-chosen alpha
        with gr.TabItem("Steer"):
            gr.HTML('<div class="section-title">Neural Steering</div>')

            with gr.Row():
                with gr.Column(scale=2):
                    prompt_steer = gr.Textbox(
                        value="The scientist discovered",
                        label="",
                        placeholder="Enter prompt...",
                        lines=5
                    )

                with gr.Column(scale=1):
                    gr.HTML('<div class="value-label">Tokens</div>')
                    tokens_steer = gr.Slider(20, 150, 60, label="", container=False)

                    gr.HTML('<div class="value-label">Alpha</div>')
                    alpha_steer = gr.Slider(-5.0, 5.0, 0.0, 0.5, label="", container=False)
                    gr.Markdown("*negative → semantic | positive → syntactic*", elem_classes=["caption"])

                    steer_btn = gr.Button("Generate", variant="primary")

            gr.HTML('<div class="section-title">Output</div>')
            output_steer = gr.Textbox(label="", lines=8, interactive=False, container=False)

            steer_btn.click(
                generate_steered,
                inputs=[prompt_steer, alpha_steer, tokens_steer],
                outputs=[output_steer]
            )

    # Footer
    gr.HTML("""
        <div style="text-align: center; color: #999; font-size: 12px; padding: 40px 0 20px 0; border-top: 1px solid #e0e0e0; margin-top: 40px;">
            © 2025 Sandro Andric | <a href="https://ainthusiast.com" style="color: #999;">Ainthusiast.com</a>
        </div>
    """)


demo.launch()
|
|