{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Qwen3.5-0.8B Two-Shot Embedding Explorer\n", "Generate descriptions via two-shot prompting, then re-encode the output to extract embeddings with actual semantic diversity.\n", "\n", "**Runtime: GPU (T4)**" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Qwen3.5 requires transformers from git main\n", "!pip install -q \"transformers @ git+https://github.com/huggingface/transformers.git@main\"\n", "!pip install -q accelerate torch matplotlib seaborn numpy scipy" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "import torch\n", "import torch.nn.functional as F\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "from scipy.spatial.distance import cosine\n", "from typing import Optional\n", "import gc\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(f'Device: {device}')\n", "if device.type == 'cuda':\n", " print(f'GPU: {torch.cuda.get_device_name()}')\n", " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Model" ] }, { "cell_type": "code", "metadata": {}, "source": [ "MODEL_ID = 'Qwen/Qwen3.5-0.8B'\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " MODEL_ID,\n", " torch_dtype=torch.bfloat16,\n", " device_map='auto',\n", " trust_remote_code=True,\n", ")\n", "model.eval()\n", "\n", "num_layers = model.config.num_hidden_layers\n", "hidden_dim = model.config.hidden_size\n", "print(f'Layers: {num_layers}, Hidden dim: {hidden_dim}')\n", "print(f'Total hidden states: {num_layers + 1}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Two-Shot Generation + Re-Encode Pipeline\n", "1. Build a two-shot chat prompt with examples\n", "2. Generate a description\n", "3. Re-encode the generated text (not the prompt) and extract all hidden states" ] }, { "cell_type": "code", "metadata": {}, "source": [ "class TwoShotEmbeddingExtractor:\n", " \"\"\"\n", " Two-shot generation -> re-encode pipeline.\n", " \n", " Step 1: Chat-template two-shot prompt -> generate description\n", " Step 2: Encode the GENERATED text alone -> extract hidden states\n", " \n", " This produces embeddings of the model's own description,\n", " which has far more semantic diversity than raw prompt encoding.\n", " \"\"\"\n", "\n", " def __init__(self, model, tokenizer, device, min_tokens=2):\n", " self.model = model\n", " self.tokenizer = tokenizer\n", " self.device = device\n", " self.min_tokens = min_tokens\n", " self.num_layers = model.config.num_hidden_layers + 1\n", " self.hidden_dim = model.config.hidden_size\n", "\n", " def build_twoshot_prompt(self, subject: str) -> str:\n", " \"\"\"Build two-shot chat prompt with visual description examples.\"\"\"\n", " messages = [\n", " {\n", " 'role': 'system',\n", " 'content': 'You describe scenes and subjects in exactly one sentence. '\n", " 'Be specific about visual features, lighting, colors, and composition.'\n", " },\n", " {\n", " 'role': 'user',\n", " 'content': 'Describe: a car on a highway'\n", " },\n", " {\n", " 'role': 'assistant',\n", " 'content': 'A silver sedan cruises along a sunlit four-lane highway '\n", " 'cutting through rolling green hills under a pale blue sky with wispy cirrus clouds.'\n", " },\n", " {\n", " 'role': 'user',\n", " 'content': 'Describe: a sunflower field'\n", " },\n", " {\n", " 'role': 'assistant',\n", " 'content': 'Thousands of tall sunflowers with bright yellow petals and dark brown centers '\n", " 'stand in dense rows across a flat field stretching to the horizon at golden hour.'\n", " },\n", " {\n", " 'role': 'user',\n", " 'content': f'Describe: {subject}'\n", " },\n", " ]\n", " return self.tokenizer.apply_chat_template(\n", " messages, tokenize=False, add_generation_prompt=True\n", " )\n", "\n", " @torch.no_grad()\n", " def generate_description(self, subject: str, max_new_tokens=80) -> str:\n", " \"\"\"Generate a one-sentence visual description via two-shot.\"\"\"\n", " prompt = self.build_twoshot_prompt(subject)\n", " inputs = self.tokenizer(prompt, return_tensors='pt').to(self.device)\n", "\n", " output_ids = self.model.generate(\n", " **inputs,\n", " max_new_tokens=max_new_tokens,\n", " do_sample=True,\n", " temperature=0.7,\n", " top_p=0.9,\n", " pad_token_id=self.tokenizer.eos_token_id,\n", " )\n", "\n", " # Decode only the new tokens\n", " new_tokens = output_ids[0][inputs['input_ids'].shape[1]:]\n", " description = self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()\n", " return description\n", "\n", " @torch.no_grad()\n", " def encode_text(self, text: str) -> dict:\n", " \"\"\"\n", " Encode text and return all hidden states.\n", " Pads ultra-short inputs to avoid conv1d crash in DeltaNet layers.\n", " \"\"\"\n", " inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n", " seq_len = inputs['input_ids'].shape[1]\n", "\n", " if seq_len < self.min_tokens:\n", " text = text + ' . .'\n", " inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n", " seq_len = inputs['input_ids'].shape[1]\n", "\n", " outputs = self.model(**inputs, output_hidden_states=True)\n", "\n", " hidden_states = outputs.hidden_states\n", " if hidden_states is None:\n", " raise RuntimeError('Model returned None for hidden_states.')\n", "\n", " input_ids = inputs['input_ids'][0]\n", " tokens = [self.tokenizer.decode(tid) for tid in input_ids]\n", "\n", " return {\n", " 'hidden_states': hidden_states,\n", " 'input_ids': input_ids,\n", " 'tokens': tokens,\n", " 'seq_len': len(tokens),\n", " }\n", "\n", " def pool_hidden_states(self, hidden_states, method='mean'):\n", " \"\"\"Pool across tokens for all layers. Returns [num_layers, hidden_dim].\"\"\"\n", " pooled = []\n", " for hs in hidden_states:\n", " hs = hs.squeeze(0) # [seq_len, hidden_dim]\n", " if method == 'mean':\n", " pooled.append(hs.mean(dim=0))\n", " elif method == 'last_token':\n", " pooled.append(hs[-1])\n", " elif method == 'max':\n", " pooled.append(hs.max(dim=0).values)\n", " else:\n", " raise ValueError(f'Unknown method: {method}')\n", " return torch.stack(pooled)\n", "\n", " def generate_and_encode(self, subject: str, method='mean') -> dict:\n", " \"\"\"\n", " Full pipeline: generate description, then re-encode it.\n", " Returns embeddings of the GENERATED text, not the prompt.\n", " \"\"\"\n", " description = self.generate_description(subject)\n", " data = self.encode_text(description)\n", " embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n", " return {\n", " 'embeddings': embeddings,\n", " 'description': description,\n", " 'tokens': data['tokens'],\n", " 'seq_len': data['seq_len'],\n", " }\n", "\n", " def encode_raw(self, text: str, method='mean') -> dict:\n", " \"\"\"\n", " Direct encode (no generation). For comparison baseline.\n", " \"\"\"\n", " data = self.encode_text(text)\n", " embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n", " return {\n", " 'embeddings': embeddings,\n", " 'description': text,\n", " 'tokens': data['tokens'],\n", " 'seq_len': data['seq_len'],\n", " }\n", "\n", "\n", "extractor = TwoShotEmbeddingExtractor(model, tokenizer, device)\n", "print(f'Extractor ready. {extractor.num_layers} layers, {extractor.hidden_dim}d')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test Generation\n", "Quick sanity check that the two-shot pipeline produces good descriptions." ] }, { "cell_type": "code", "metadata": {}, "source": [ "test_subjects = [\n", " 'a cat on a windowsill',\n", " 'a red cube on a blue floor',\n", " 'an oil painting of a stormy sea',\n", " 'darkness',\n", " 'cheese',\n", "]\n", "\n", "print('Two-shot generation test:')\n", "print('=' * 70)\n", "for subject in test_subjects:\n", " desc = extractor.generate_description(subject)\n", " print(f'\\nSubject: {subject}')\n", " print(f'Generated: {desc}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define Test Subjects\n", "Same groups as before. Each gets a two-shot generated description + raw encode for comparison." ] }, { "cell_type": "code", "metadata": {}, "source": [ "SUBJECT_GROUPS = {\n", " 'photorealistic': [\n", " 'a cat sitting on a windowsill in golden hour light',\n", " 'a mountain landscape at sunset with dramatic clouds',\n", " 'an elderly man with weathered skin and blue eyes',\n", " ],\n", " 'artistic': [\n", " 'an oil painting of a stormy sea',\n", " 'a quiet Japanese garden with cherry blossoms',\n", " 'abstract geometric shapes overlapping',\n", " ],\n", " 'semantic_shift': [\n", " 'a red cube on a blue floor',\n", " 'a blue cube on a red floor',\n", " 'a green sphere floating above a white plane',\n", " ],\n", " 'gibberish': [\n", " 'mxkrl vvtonp qazhif bwsdee lpoqnr yttmz',\n", " 'florpnax grindleby shovantic wumblecrax tazzifer',\n", " 'aaaa bbbb cccc dddd eeee ffff gggg hhhh',\n", " ],\n", " 'short': [\n", " 'taco',\n", " '1girl',\n", " 'cheese',\n", " 'cheddar bacon sub',\n", " ],\n", "}\n", "\n", "subjects = []\n", "subject_labels = []\n", "subject_groups = []\n", "for group_name, group_items in SUBJECT_GROUPS.items():\n", " for s in group_items:\n", " subjects.append(s)\n", " label = s[:40] + '...' if len(s) > 40 else s\n", " subject_labels.append(label)\n", " subject_groups.append(group_name)\n", "\n", "print(f'{len(subjects)} subjects across {len(SUBJECT_GROUPS)} groups')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate Descriptions + Extract Embeddings\n", "Both two-shot (generated) and raw (direct encode) for comparison." ] }, { "cell_type": "code", "metadata": {}, "source": [ "POOL_METHODS = ['mean', 'last_token']\n", "\n", "# Two-shot generated embeddings\n", "twoshot_embeddings = {method: {} for method in POOL_METHODS}\n", "twoshot_descriptions = {}\n", "twoshot_token_counts = {}\n", "\n", "# Raw direct-encode embeddings (baseline)\n", "raw_embeddings = {method: {} for method in POOL_METHODS}\n", "raw_token_counts = {}\n", "\n", "print('=== TWO-SHOT GENERATION + ENCODE ===')\n", "print('=' * 70)\n", "for i, subject in enumerate(subjects):\n", " print(f'\\n[{i+1}/{len(subjects)}] \"{subject_labels[i]}\"')\n", " for method in POOL_METHODS:\n", " result = extractor.generate_and_encode(subject, method=method)\n", " twoshot_embeddings[method][i] = result['embeddings'].float().cpu()\n", " if method == POOL_METHODS[0]:\n", " twoshot_descriptions[i] = result['description']\n", " twoshot_token_counts[i] = result['seq_len']\n", " print(f' -> \"{result[\"description\"][:80]}...\"' if len(result['description']) > 80 else f' -> \"{result[\"description\"]}\"')\n", "\n", "print('\\n\\n=== RAW DIRECT ENCODE (BASELINE) ===')\n", "print('=' * 70)\n", "for i, subject in enumerate(subjects):\n", " print(f'[{i+1}/{len(subjects)}] \"{subject_labels[i]}\"')\n", " for method in POOL_METHODS:\n", " result = extractor.encode_raw(subject, method=method)\n", " raw_embeddings[method][i] = result['embeddings'].float().cpu()\n", " if method == POOL_METHODS[0]:\n", " raw_token_counts[i] = result['seq_len']\n", "\n", "n_subjects = len(subjects)\n", "n_layers = extractor.num_layers\n", "\n", "print(f'\\nDone. {n_subjects} subjects, {n_layers} layers, {extractor.hidden_dim}d')\n", "print(f'Two-shot token counts: {list(twoshot_token_counts.values())}')\n", "print(f'Raw token counts: {list(raw_token_counts.values())}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cosine Similarity Matrices" ] }, { "cell_type": "code", "metadata": {}, "source": [ "def compute_pairwise_cosine(embeddings_dict, num_prompts, num_layers):\n", " sim_matrix = np.zeros((num_layers, num_prompts, num_prompts))\n", " for layer_idx in range(num_layers):\n", " for i in range(num_prompts):\n", " for j in range(num_prompts):\n", " if i == j:\n", " sim_matrix[layer_idx, i, j] = 1.0\n", " elif j > i:\n", " vec_i = embeddings_dict[i][layer_idx].numpy()\n", " vec_j = embeddings_dict[j][layer_idx].numpy()\n", " sim = 1.0 - cosine(vec_i, vec_j)\n", " sim_matrix[layer_idx, i, j] = sim\n", " sim_matrix[layer_idx, j, i] = sim\n", " return sim_matrix\n", "\n", "# Compute for both pipelines, both pooling methods\n", "twoshot_sim = {}\n", "raw_sim = {}\n", "for method in POOL_METHODS:\n", " twoshot_sim[method] = compute_pairwise_cosine(twoshot_embeddings[method], n_subjects, n_layers)\n", " raw_sim[method] = compute_pairwise_cosine(raw_embeddings[method], n_subjects, n_layers)\n", " print(f'{method}: twoshot {twoshot_sim[method].shape}, raw {raw_sim[method].shape}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Head-to-Head: Two-Shot vs Raw at Best Layer\n", "Side-by-side heatmaps showing how two-shot generation changes the similarity landscape." ] }, { "cell_type": "code", "metadata": {}, "source": [ "def plot_comparison_heatmaps(twoshot_sim, raw_sim, labels, method, layer_idx):\n", " \"\"\"Side-by-side: raw vs two-shot at a specific layer.\"\"\"\n", " layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n", " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 10))\n", "\n", " sns.heatmap(\n", " raw_sim[layer_idx], xticklabels=labels, yticklabels=labels,\n", " vmin=-0.2, vmax=1.0, cmap='RdYlBu_r', annot=True, fmt='.2f',\n", " ax=ax1, square=True, annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n", " )\n", " ax1.set_title(f'RAW encode | {method} | {layer_name}', fontsize=14)\n", " ax1.tick_params(axis='x', rotation=90, labelsize=7)\n", " ax1.tick_params(axis='y', rotation=0, labelsize=7)\n", "\n", " sns.heatmap(\n", " twoshot_sim[layer_idx], xticklabels=labels, yticklabels=labels,\n", " vmin=-0.2, vmax=1.0, cmap='RdYlBu_r', annot=True, fmt='.2f',\n", " ax=ax2, square=True, annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n", " )\n", " ax2.set_title(f'TWO-SHOT encode | {method} | {layer_name}', fontsize=14)\n", " ax2.tick_params(axis='x', rotation=90, labelsize=7)\n", " ax2.tick_params(axis='y', rotation=0, labelsize=7)\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "short_labels = [l[:30] for l in subject_labels]\n", "\n", "# Show at penultimate and final layer for last_token\n", "for layer_idx in [n_layers - 2, n_layers - 1]:\n", " plot_comparison_heatmaps(\n", " twoshot_sim['last_token'], raw_sim['last_token'],\n", " short_labels, 'last_token', layer_idx\n", " )" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Two-Shot Heatmap Grid (All Sampled Layers)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "def plot_heatmap_grid(sim_matrix, labels, method_name, title_prefix=''):\n", " n_layers = sim_matrix.shape[0]\n", " layers_to_show = sorted(set([\n", " 0, n_layers // 4, n_layers // 2,\n", " 3 * n_layers // 4, n_layers - 2, n_layers - 1,\n", " ]))\n", "\n", " fig, axes = plt.subplots(2, 3, figsize=(24, 20))\n", " axes = axes.flatten()\n", "\n", " for idx, (ax, layer_idx) in enumerate(zip(axes, layers_to_show)):\n", " layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n", " sns.heatmap(\n", " sim_matrix[layer_idx],\n", " xticklabels=labels, yticklabels=labels,\n", " vmin=-0.2, vmax=1.0, cmap='RdYlBu_r',\n", " annot=True, fmt='.2f', ax=ax, square=True,\n", " annot_kws={'size': 6}, cbar_kws={'shrink': 0.6},\n", " )\n", " ax.set_title(f'{title_prefix}{method_name} | {layer_name}', fontsize=13)\n", " ax.tick_params(axis='x', rotation=90, labelsize=7)\n", " ax.tick_params(axis='y', rotation=0, labelsize=7)\n", "\n", " for idx in range(len(layers_to_show), len(axes)):\n", " axes[idx].set_visible(False)\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "for method in POOL_METHODS:\n", " print(f'\\n=== TWO-SHOT | {method.upper()} ===')\n", " plot_heatmap_grid(twoshot_sim[method], short_labels, method, 'twoshot | ')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Discriminability: Two-Shot vs Raw" ] }, { "cell_type": "code", "metadata": {}, "source": [ "def compute_discriminability(sim_matrix, group_labels):\n", " n_layers = sim_matrix.shape[0]\n", " n = sim_matrix.shape[1]\n", " within_sim = np.zeros(n_layers)\n", " between_sim = np.zeros(n_layers)\n", "\n", " for layer in range(n_layers):\n", " w_vals, b_vals = [], []\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " val = sim_matrix[layer, i, j]\n", " if group_labels[i] == group_labels[j]:\n", " w_vals.append(val)\n", " else:\n", " b_vals.append(val)\n", " within_sim[layer] = np.mean(w_vals) if w_vals else 0\n", " between_sim[layer] = np.mean(b_vals) if b_vals else 0\n", "\n", " return within_sim, between_sim, within_sim - between_sim\n", "\n", "\n", "fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n", "\n", "configs = [\n", " ('mean', raw_sim, 'RAW | mean'),\n", " ('mean', twoshot_sim, 'TWO-SHOT | mean'),\n", " ('last_token', raw_sim, 'RAW | last_token'),\n", " ('last_token', twoshot_sim, 'TWO-SHOT | last_token'),\n", "]\n", "\n", "for ax, (method, sim_dict, title) in zip(axes.flatten(), configs):\n", " within, between, gap = compute_discriminability(sim_dict[method], subject_groups)\n", " layer_x = np.arange(n_layers)\n", "\n", " ax.plot(layer_x, within, label='Within-group', color='#2196F3', linewidth=2)\n", " ax.plot(layer_x, between, label='Between-group', color='#FF5722', linewidth=2)\n", " ax.fill_between(layer_x, between, within, alpha=0.15, color='green')\n", " ax.plot(layer_x, gap, label='Gap', color='green', linewidth=2, linestyle='--')\n", "\n", " best = np.argmax(gap)\n", " ax.axvline(best, color='green', linestyle=':', alpha=0.5)\n", " ax.annotate(f'Best: L{best} ({gap[best]:.3f})', xy=(best, gap[best]),\n", " xytext=(best + 1, gap[best] + 0.02),\n", " arrowprops=dict(arrowstyle='->', color='green'),\n", " fontsize=9, color='green')\n", "\n", " ax.set_xlabel('Layer')\n", " ax.set_ylabel('Cosine Similarity')\n", " ax.set_title(title, fontsize=13)\n", " ax.legend(fontsize=8)\n", " ax.grid(True, alpha=0.3)\n", "\n", "plt.tight_layout()\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Similarity Statistics: Two-Shot vs Raw" ] }, { "cell_type": "code", "metadata": {}, "source": [ "print('=' * 70)\n", "print('SIMILARITY STATISTICS COMPARISON')\n", "print('=' * 70)\n", "\n", "for method in POOL_METHODS:\n", " print(f'\\n--- {method.upper()} ---')\n", " for label, sim_dict in [('RAW', raw_sim), ('TWO-SHOT', twoshot_sim)]:\n", " # Use penultimate layer\n", " layer = n_layers - 2\n", " mat = sim_dict[method][layer]\n", " off_diag = mat[np.triu_indices(n_subjects, k=1)]\n", "\n", " print(f' {label} (L{layer}):')\n", " print(f' Mean sim: {off_diag.mean():.4f}')\n", " print(f' Std sim: {off_diag.std():.4f}')\n", " print(f' Min sim: {off_diag.min():.4f}')\n", " print(f' Max sim: {off_diag.max():.4f}')\n", " print(f' Near-zero (<0.05): {(off_diag < 0.05).sum()}')\n", " print(f' High (>0.9): {(off_diag > 0.9).sum()}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Norms & Effective Dimensionality (Two-Shot)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n", "\n", "for method, ax in zip(POOL_METHODS, axes):\n", " for i in range(n_subjects):\n", " norms = twoshot_embeddings[method][i].norm(dim=-1).numpy()\n", " ax.plot(range(n_layers), norms, alpha=0.6, label=short_labels[i][:20])\n", " ax.set_xlabel('Layer')\n", " ax.set_ylabel('L2 Norm')\n", " ax.set_title(f'TWO-SHOT | {method} | Embedding Norms')\n", " ax.grid(True, alpha=0.3)\n", " ax.legend(fontsize=6, loc='upper left')\n", "\n", "plt.tight_layout()\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "def effective_dimensionality(embeddings_list):\n", " mat = torch.stack(embeddings_list)\n", " mat = mat - mat.mean(dim=0)\n", " _, S, _ = torch.svd(mat)\n", " S = S / S.sum()\n", " return 1.0 / (S ** 2).sum().item()\n", "\n", "fig, ax = plt.subplots(figsize=(10, 5))\n", "\n", "for label, emb_dict, ls in [('raw', raw_embeddings, '--'), ('twoshot', twoshot_embeddings, '-')]:\n", " for method, color in zip(POOL_METHODS, ['#2196F3', '#FF5722']):\n", " eff_dims = []\n", " for layer_idx in range(n_layers):\n", " vecs = [emb_dict[method][i][layer_idx] for i in range(n_subjects)]\n", " eff_dims.append(effective_dimensionality(vecs))\n", " ax.plot(range(n_layers), eff_dims, label=f'{label} | {method}',\n", " linewidth=2, linestyle=ls, color=color)\n", "\n", "ax.set_xlabel('Layer')\n", "ax.set_ylabel('Effective Dimensionality')\n", "ax.set_title('Effective Rank: Raw (dashed) vs Two-Shot (solid)')\n", "ax.legend()\n", "ax.grid(True, alpha=0.3)\n", "plt.tight_layout()\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Summary" ] }, { "cell_type": "code", "metadata": {}, "source": [ "print('=' * 70)\n", "print('TWO-SHOT vs RAW EMBEDDING SUMMARY')\n", "print('=' * 70)\n", "print(f'Model: {MODEL_ID}')\n", "print(f'Layers: {n_layers}, Hidden dim: {extractor.hidden_dim}')\n", "print(f'Subjects: {n_subjects}')\n", "print()\n", "\n", "for method in POOL_METHODS:\n", " print(f'--- {method.upper()} ---')\n", " for label, sim_dict in [('RAW', raw_sim), ('TWO-SHOT', twoshot_sim)]:\n", " within, between, gap = compute_discriminability(sim_dict[method], subject_groups)\n", " best_l = np.argmax(gap)\n", " print(f' {label}:')\n", " print(f' Best layer: L{best_l} (gap = {gap[best_l]:.4f})')\n", " print(f' Final layer gap: {gap[-1]:.4f}')\n", " print()\n", "\n", "print('\\nGENERATED DESCRIPTIONS:')\n", "for i, subject in enumerate(subjects):\n", " print(f' [{subject_labels[i]}]')\n", " print(f' -> {twoshot_descriptions[i]}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## (Optional) Export" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Uncomment to save\n", "# export = {\n", "# 'model_id': MODEL_ID,\n", "# 'subjects': subjects,\n", "# 'subject_groups': subject_groups,\n", "# 'twoshot_descriptions': twoshot_descriptions,\n", "# 'twoshot_embeddings': twoshot_embeddings,\n", "# 'raw_embeddings': raw_embeddings,\n", "# 'twoshot_sim': twoshot_sim,\n", "# 'raw_sim': raw_sim,\n", "# 'num_layers': n_layers,\n", "# 'hidden_dim': extractor.hidden_dim,\n", "# }\n", "# torch.save(export, 'qwen35_twoshot_embeddings.pt')\n", "# print('Saved.')" ], "execution_count": null, "outputs": [] } ] }