{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [], "gpuType": "T4" }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Qwen3.5-0.8B Embedding Explorer\n", "Extract all-layer embeddings, compare prompt similarity, and evaluate potential for diffusion conditioning.\n", "\n", "**Runtime: GPU (T4 is fine for 0.8B)**" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Qwen3.5 requires transformers from git main (not yet in PyPI release).\n", "# %pip (rather than !pip) guarantees the install targets this kernel's environment.\n", "%pip install -q \"transformers @ git+https://github.com/huggingface/transformers.git@main\"\n", "%pip install -q accelerate torch matplotlib seaborn numpy scipy" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": {}, "source": [ "import torch\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from transformers import AutoTokenizer, AutoModelForCausalLM\n", "from scipy.spatial.distance import cosine\n", "from typing import Optional\n", "import gc\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "print(f'Device: {device}')\n", "if device.type == 'cuda':\n", " print(f'GPU: {torch.cuda.get_device_name()}')\n", " # _CudaDeviceProperties exposes total_memory; .total_mem does not exist and raises AttributeError\n", " print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load Model" ] }, { "cell_type": "code", "metadata": {}, "source": [ "MODEL_ID = 'Qwen/Qwen3.5-0.8B'\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)\n", "model = AutoModelForCausalLM.from_pretrained(\n", " MODEL_ID,\n", " torch_dtype=torch.bfloat16,\n", " device_map='auto',\n", " trust_remote_code=True,\n", " output_hidden_states=True, # Critical: get all layer outputs\n", ")\n", "model.eval()\n", "\n", "num_layers = model.config.num_hidden_layers\n", "hidden_dim 
= model.config.hidden_size\n", "print(f'Layers: {num_layers}, Hidden dim: {hidden_dim}')\n", "print(f'Total hidden states returned: {num_layers + 1} (embedding layer + {num_layers} transformer layers)')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Embedding Extraction Engine\n", "Extracts hidden states from all layers with multiple pooling strategies." ] }, { "cell_type": "code", "metadata": {}, "source": [ "class QwenEmbeddingExtractor:\n", " \"\"\"Extract and pool hidden states from all layers of Qwen3.5-0.8B.\"\"\"\n", "\n", " def __init__(self, model, tokenizer, device):\n", " self.model = model\n", " self.tokenizer = tokenizer\n", " self.device = device\n", " self.num_layers = model.config.num_hidden_layers + 1 # +1 for embedding layer\n", " self.hidden_dim = model.config.hidden_size\n", "\n", " @torch.no_grad()\n", " def extract_hidden_states(self, text: str) -> dict:\n", " \"\"\"\n", " Run forward pass and return all hidden states + metadata.\n", "\n", " Returns dict with:\n", " - hidden_states: tuple of (num_layers+1) tensors, each [1, seq_len, hidden_dim]\n", " - input_ids: token IDs\n", " - tokens: decoded token strings\n", " - seq_len: number of tokens\n", " \"\"\"\n", " inputs = self.tokenizer(text, return_tensors='pt').to(self.device)\n", " # An empty prompt can tokenize to zero tokens (no BOS is added here by default),\n", " # which would crash the forward pass and make every pooling method ill-defined\n", " # (hs[-1] on an empty sequence, mean over zero rows). Fall back to a single\n", " # EOS token so the '' baseline still yields a well-defined embedding.\n", " if inputs['input_ids'].shape[1] == 0:\n", " inputs = {\n", " 'input_ids': torch.tensor([[self.tokenizer.eos_token_id]], device=self.device),\n", " 'attention_mask': torch.ones(1, 1, dtype=torch.long, device=self.device),\n", " }\n", " outputs = self.model(**inputs)\n", "\n", " hidden_states = outputs.hidden_states # tuple of (num_layers+1) tensors\n", " input_ids = inputs['input_ids'][0]\n", " tokens = [self.tokenizer.decode(tid) for tid in input_ids]\n", "\n", " return {\n", " 'hidden_states': hidden_states,\n", " 'input_ids': input_ids,\n", " 'tokens': tokens,\n", " 'seq_len': len(tokens),\n", " }\n", "\n", " def pool_hidden_states(\n", " self,\n", " hidden_states: tuple,\n", " method: str = 'mean',\n", " layer_indices: Optional[list] = None,\n", " ) -> torch.Tensor:\n", " \"\"\"\n", " Pool hidden states across tokens for specified layers.\n", "\n", " Args:\n", " hidden_states: tuple 
from extract_hidden_states\n", " method: 'mean', 'last_token', 'max', or 'all_tokens'\n", " layer_indices: which layers to return (None = all)\n", "\n", " Returns:\n", " For 'all_tokens': [num_layers, seq_len, hidden_dim]\n", " Otherwise: [num_layers, hidden_dim]\n", " \"\"\"\n", " if layer_indices is None:\n", " layer_indices = list(range(len(hidden_states)))\n", "\n", " pooled = []\n", " for idx in layer_indices:\n", " hs = hidden_states[idx].squeeze(0) # [seq_len, hidden_dim]\n", "\n", " if method == 'mean':\n", " pooled.append(hs.mean(dim=0)) # [hidden_dim]\n", " elif method == 'last_token':\n", " pooled.append(hs[-1]) # [hidden_dim]\n", " elif method == 'max':\n", " pooled.append(hs.max(dim=0).values) # [hidden_dim]\n", " elif method == 'all_tokens':\n", " pooled.append(hs) # [seq_len, hidden_dim]\n", " else:\n", " raise ValueError(f'Unknown pooling method: {method}')\n", "\n", " return torch.stack(pooled)\n", "\n", " def extract_and_pool(self, text: str, method: str = 'mean') -> dict:\n", " \"\"\"\n", " Convenience: extract + pool in one call.\n", "\n", " Returns dict with:\n", " - embeddings: [num_layers, hidden_dim] (or [num_layers, seq_len, hidden_dim] for all_tokens)\n", " - tokens: list of token strings\n", " - seq_len: int\n", " \"\"\"\n", " data = self.extract_hidden_states(text)\n", " embeddings = self.pool_hidden_states(data['hidden_states'], method=method)\n", " return {\n", " 'embeddings': embeddings,\n", " 'tokens': data['tokens'],\n", " 'seq_len': data['seq_len'],\n", " }\n", "\n", "extractor = QwenEmbeddingExtractor(model, tokenizer, device)\n", "print(f'Extractor ready. Will return {extractor.num_layers} layer embeddings per prompt.')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Define Test Prompts\n", "Edit these to whatever you want to compare. Grouped by semantic category to see clustering behavior."
] }, { "cell_type": "code", "metadata": {}, "source": [ "# ---- EDIT THESE ----\n", "# Groups help visualize clustering. Flat list is also fine.\n", "PROMPT_GROUPS = {\n", " 'photorealistic': [\n", " 'a photograph of a cat sitting on a windowsill in golden hour light',\n", " 'professional photo of a mountain landscape at sunset with dramatic clouds',\n", " 'close-up portrait of an elderly man with weathered skin and blue eyes',\n", " ],\n", " 'artistic': [\n", " 'an oil painting of a stormy sea in the style of Turner',\n", " 'watercolor illustration of a quiet Japanese garden with cherry blossoms',\n", " 'abstract geometric composition with overlapping translucent shapes',\n", " ],\n", " 'semantic_shift': [\n", " 'a red cube on a blue floor',\n", " 'a blue cube on a red floor',\n", " 'a green sphere floating above a white plane',\n", " ],\n", " 'edge_cases': [\n", " 'darkness',\n", " '', # empty string baseline\n", " 'asdfghjkl random noise tokens xyzzy',\n", " ],\n", "}\n", "\n", "# Flatten for processing\n", "prompts = []\n", "prompt_labels = []\n", "prompt_groups = []\n", "for group_name, group_prompts in PROMPT_GROUPS.items():\n", " for p in group_prompts:\n", " prompts.append(p)\n", " label = p[:50] + '...' 
if len(p) > 50 else p\n", " # The empty-string baseline would otherwise get a blank label, which is\n", " # invisible in printed progress lines and on heatmap axes.\n", " label = label if label else '<empty>'\n", " prompt_labels.append(label)\n", " prompt_groups.append(group_name)\n", "\n", "print(f'{len(prompts)} prompts across {len(PROMPT_GROUPS)} groups')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Extract All Embeddings" ] }, { "cell_type": "code", "metadata": {}, "source": [ "POOL_METHODS = ['mean', 'last_token']\n", "\n", "# Store results: {method: {prompt_idx: [num_layers, hidden_dim]}}\n", "all_embeddings = {method: {} for method in POOL_METHODS}\n", "token_counts = {}\n", "\n", "for i, prompt in enumerate(prompts):\n", " print(f'[{i+1}/{len(prompts)}] ({len(prompt)} chars) \"{prompt_labels[i]}\"')\n", " for method in POOL_METHODS:\n", " result = extractor.extract_and_pool(prompt, method=method)\n", " all_embeddings[method][i] = result['embeddings'].float().cpu() # [num_layers, hidden_dim]\n", " if method == POOL_METHODS[0]:\n", " token_counts[i] = result['seq_len']\n", "\n", "print(f'\\nDone. Shape per prompt per method: {all_embeddings[\"mean\"][0].shape}')\n", "print(f'Token counts: {list(token_counts.values())}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Cosine Similarity Analysis\n", "Compute pairwise similarity at every layer, for each pooling method."
] }, { "cell_type": "code", "metadata": {}, "source": [ "def compute_pairwise_cosine(embeddings_dict, num_prompts, num_layers):\n", " \"\"\"\n", " Compute cosine similarity between all prompt pairs at each layer.\n", "\n", " Returns: [num_layers, num_prompts, num_prompts] numpy array\n", " \"\"\"\n", " # scipy's cosine() is a DISTANCE, so similarity = 1 - distance.\n", " # Only the upper triangle (j > i) is computed; the second write mirrors it to j < i.\n", " sim_matrix = np.zeros((num_layers, num_prompts, num_prompts))\n", "\n", " for layer_idx in range(num_layers):\n", " for i in range(num_prompts):\n", " for j in range(num_prompts):\n", " if i == j:\n", " sim_matrix[layer_idx, i, j] = 1.0\n", " elif j > i:\n", " vec_i = embeddings_dict[i][layer_idx].numpy()\n", " vec_j = embeddings_dict[j][layer_idx].numpy()\n", " sim = 1.0 - cosine(vec_i, vec_j)\n", " sim_matrix[layer_idx, i, j] = sim\n", " sim_matrix[layer_idx, j, i] = sim\n", "\n", " return sim_matrix\n", "\n", "n_prompts = len(prompts)\n", "# extractor.num_layers already counts the input-embedding layer (index 0)\n", "n_layers = extractor.num_layers\n", "\n", "sim_matrices = {}\n", "for method in POOL_METHODS:\n", " sim_matrices[method] = compute_pairwise_cosine(\n", " all_embeddings[method], n_prompts, n_layers\n", " )\n", " print(f'{method}: similarity matrix shape = {sim_matrices[method].shape}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Heatmaps: Per-Layer Similarity\n", "Shows how prompt-pair similarity evolves across layers."
] }, { "cell_type": "code", "metadata": {}, "source": [ "def plot_similarity_heatmaps(sim_matrix, labels, method_name, layers_to_show=None):\n", " \"\"\"\n", " Plot similarity heatmaps for selected layers.\n", " If layers_to_show is None, picks: first, 25%, 50%, 75%, penultimate, last.\n", " \"\"\"\n", " n_layers = sim_matrix.shape[0]\n", "\n", " if layers_to_show is None:\n", " # sorted(set(...)) dedupes indices that coincide for small layer counts\n", " layers_to_show = sorted(set([\n", " 0,\n", " n_layers // 4,\n", " n_layers // 2,\n", " 3 * n_layers // 4,\n", " n_layers - 2, # penultimate\n", " n_layers - 1, # final\n", " ]))\n", "\n", " n_plots = len(layers_to_show)\n", " fig, axes = plt.subplots(1, n_plots, figsize=(6 * n_plots, 5))\n", " # subplots returns a bare Axes (not an array) when there is a single panel\n", " if n_plots == 1:\n", " axes = [axes]\n", "\n", " for ax, layer_idx in zip(axes, layers_to_show):\n", " # index 0 is the input-embedding output, not a transformer layer\n", " layer_name = 'embed' if layer_idx == 0 else f'L{layer_idx}'\n", " sns.heatmap(\n", " sim_matrix[layer_idx],\n", " xticklabels=labels, yticklabels=labels,\n", " vmin=-0.2, vmax=1.0,\n", " cmap='RdYlBu_r', annot=True, fmt='.2f',\n", " ax=ax, square=True,\n", " cbar_kws={'shrink': 0.6},\n", " )\n", " ax.set_title(f'{method_name} | {layer_name}', fontsize=11)\n", " ax.tick_params(axis='x', rotation=45)\n", " ax.tick_params(axis='y', rotation=0)\n", "\n", " plt.tight_layout()\n", " plt.show()\n", "\n", "# Short labels for readability\n", "short_labels = [l[:30] for l in prompt_labels]\n", "\n", "for method in POOL_METHODS:\n", " print(f'\\n=== {method.upper()} POOLING ===')\n", " plot_similarity_heatmaps(sim_matrices[method], short_labels, method)" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Layer-wise Discriminability\n", "For each layer, compute average within-group similarity vs. between-group similarity.\n", "Higher gap = better semantic clustering = more useful for conditioning."
] }, { "cell_type": "code", "metadata": {}, "source": [ "def compute_discriminability(sim_matrix, group_labels):\n", " \"\"\"\n", " Per-layer: avg within-group sim, avg between-group sim, and gap.\n", " Returns arrays of shape [num_layers].\n", " \"\"\"\n", " n_layers = sim_matrix.shape[0]\n", " n = sim_matrix.shape[1]\n", "\n", " within_sim = np.zeros(n_layers)\n", " between_sim = np.zeros(n_layers)\n", "\n", " # Iterate the upper triangle only (j > i): each unordered pair counts once,\n", " # and the diagonal (self-similarity == 1) is excluded from both averages.\n", " for layer in range(n_layers):\n", " w_vals, b_vals = [], []\n", " for i in range(n):\n", " for j in range(i + 1, n):\n", " val = sim_matrix[layer, i, j]\n", " if group_labels[i] == group_labels[j]:\n", " w_vals.append(val)\n", " else:\n", " b_vals.append(val)\n", " within_sim[layer] = np.mean(w_vals) if w_vals else 0\n", " between_sim[layer] = np.mean(b_vals) if b_vals else 0\n", "\n", " return within_sim, between_sim, within_sim - between_sim\n", "\n", "\n", "fig, axes = plt.subplots(1, len(POOL_METHODS), figsize=(8 * len(POOL_METHODS), 5))\n", "if len(POOL_METHODS) == 1:\n", " axes = [axes]\n", "\n", "best_layers = {}\n", "for ax, method in zip(axes, POOL_METHODS):\n", " within, between, gap = compute_discriminability(sim_matrices[method], prompt_groups)\n", "\n", " layer_x = np.arange(n_layers)\n", " ax.plot(layer_x, within, label='Within-group sim', color='#2196F3', linewidth=2)\n", " ax.plot(layer_x, between, label='Between-group sim', color='#FF5722', linewidth=2)\n", " ax.fill_between(layer_x, between, within, alpha=0.15, color='green')\n", " ax.plot(layer_x, gap, label='Gap (discriminability)', color='green', linewidth=2, linestyle='--')\n", "\n", " best_layer = np.argmax(gap) # first layer index attaining the max gap\n", " best_layers[method] = best_layer\n", " ax.axvline(best_layer, color='green', linestyle=':', alpha=0.5)\n", " ax.annotate(f'Best: L{best_layer}', xy=(best_layer, gap[best_layer]),\n", " xytext=(best_layer + 1, gap[best_layer] + 0.02),\n", " arrowprops=dict(arrowstyle='->', color='green'),\n", " fontsize=10, color='green')\n", "\n", " ax.set_xlabel('Layer Index')\n", " ax.set_ylabel('Cosine Similarity')\n", " ax.set_title(f'{method} pooling — Semantic Discriminability')\n", " ax.legend()\n", " ax.grid(True, alpha=0.3)\n", "\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print('\\nBest discriminability layers:')\n", "for method, layer in best_layers.items():\n", " print(f' {method}: layer {layer}')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Embedding Norm & Variance Across Layers\n", "Checks for collapse (all norms converging) or explosion — both bad for conditioning." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Panel count follows POOL_METHODS (a hard-coded 2 would desync if the list is edited)\n", "fig, axes = plt.subplots(1, len(POOL_METHODS), figsize=(7 * len(POOL_METHODS), 5))\n", "axes = np.atleast_1d(axes) # subplots returns a bare Axes for a single panel\n", "\n", "# Norms per layer per prompt\n", "for method, ax in zip(POOL_METHODS, axes):\n", " for i in range(n_prompts):\n", " norms = all_embeddings[method][i].norm(dim=-1).numpy() # [num_layers]\n", " ax.plot(range(n_layers), norms, alpha=0.6, label=short_labels[i][:20])\n", "\n", " ax.set_xlabel('Layer')\n", " ax.set_ylabel('L2 Norm')\n", " ax.set_title(f'{method} pooling — Embedding Norms')\n", " ax.grid(True, alpha=0.3)\n", " ax.legend(fontsize=7, loc='upper left')\n", "\n", "plt.tight_layout()\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Effective Dimensionality per Layer\n", "How many dimensions are actually being used? Low rank = bad for diffusion conditioning diversity."
] }, { "cell_type": "code", "metadata": {}, "source": [ "def effective_dimensionality(embeddings_list):\n", " \"\"\"\n", " Compute effective dimensionality via participation ratio of singular values.\n", " embeddings_list: list of [hidden_dim] vectors\n", " Returns: float (effective rank)\n", " \"\"\"\n", " mat = torch.stack(embeddings_list) # [n_prompts, hidden_dim]\n", " mat = mat - mat.mean(dim=0) # center\n", " # torch.svd is deprecated; torch.linalg.svd returns (U, S, Vh) with S descending.\n", " # Only the singular values are needed here.\n", " S = torch.linalg.svd(mat, full_matrices=False).S\n", " # clamp_min guards the degenerate case where all centered vectors are zero\n", " S = S / S.sum().clamp_min(1e-12)\n", " # Participation ratio of normalized singular values: (sum s)^2 / sum s^2 with sum s == 1\n", " participation_ratio = 1.0 / (S ** 2).sum().item()\n", " return participation_ratio\n", "\n", "\n", "for method in POOL_METHODS:\n", " eff_dims = []\n", " for layer_idx in range(n_layers):\n", " layer_vecs = [all_embeddings[method][i][layer_idx] for i in range(n_prompts)]\n", " ed = effective_dimensionality(layer_vecs)\n", " eff_dims.append(ed)\n", "\n", " plt.plot(range(n_layers), eff_dims, label=method, linewidth=2)\n", "\n", "plt.xlabel('Layer')\n", "plt.ylabel('Effective Dimensionality (participation ratio)')\n", "plt.title('Effective Rank of Embedding Space per Layer')\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)\n", "plt.tight_layout()\n", "plt.show()" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Quick Diffusion Conditioning Assessment\n", "Summary: which layers look most promising as conditioning vectors?"
] }, { "cell_type": "code", "metadata": {}, "source": [ "print('=' * 70)\n", "print('DIFFUSION CONDITIONING VIABILITY SUMMARY')\n", "print('=' * 70)\n", "print(f'\\nModel: {MODEL_ID}')\n", "print(f'Layers: {n_layers} (0=input embeddings, rest=transformer layers)')\n", "print(f'Hidden dim: {hidden_dim}')\n", "print(f'Prompts tested: {n_prompts}')\n", "print()\n", "\n", "for method in POOL_METHODS:\n", " within, between, gap = compute_discriminability(sim_matrices[method], prompt_groups)\n", " best_l = np.argmax(gap)\n", " worst_l = np.argmin(gap)\n", "\n", " # Check for near-collapse: if all pairwise sims > 0.95 at any layer\n", " # (triu_indices with k=1 selects the strict upper triangle, excluding the diagonal)\n", " collapse_layers = []\n", " for l in range(n_layers):\n", " off_diag = sim_matrices[method][l][np.triu_indices(n_prompts, k=1)]\n", " if off_diag.min() > 0.95:\n", " collapse_layers.append(l)\n", "\n", " print(f'--- {method.upper()} POOLING ---')\n", " print(f' Best discriminability: Layer {best_l} (gap = {gap[best_l]:.4f})')\n", " print(f' Worst discriminability: Layer {worst_l} (gap = {gap[worst_l]:.4f})')\n", " # gap[-2]/gap[-1] are the last two hidden-state entries: the penultimate and\n", " # final transformer layers (entry 0 is the input-embedding output)\n", " print(f' Penultimate layer gap: {gap[-2]:.4f}')\n", " print(f' Final layer gap: {gap[-1]:.4f}')\n", " if collapse_layers:\n", " print(f' WARNING: Near-collapse at layers: {collapse_layers}')\n", " else:\n", " print(f' No collapse detected (all layers have some discrimination)')\n", " print()\n", "\n", "print('RECOMMENDATIONS:')\n", "print(' For POOLED conditioning (global vector): Use the best discriminability layer.')\n", "print(' For TOKEN-LEVEL conditioning (cross-attention): Re-run with method=\"all_tokens\"')\n", "print(' and compare token-level structure against T5/CLIP token outputs.')\n", "print(' Watch for: norm explosion in later layers (may need LayerNorm before conditioning).')\n", "print(' The penultimate layer often outperforms the final layer (CLIP effect).')" ], "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## (Optional) Export Embeddings for Further Analysis\n", 
"Save to disk for loading into your geometric pipeline." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Uncomment to save\n", "# NOTE(review): the export mixes tensors with numpy arrays and Python lists; recent\n", "# torch versions default torch.load to weights_only=True, so reload with\n", "# torch.load('qwen35_0.8b_embeddings.pt', weights_only=False) — confirm against your torch version.\n", "# export = {\n", "# 'model_id': MODEL_ID,\n", "# 'prompts': prompts,\n", "# 'prompt_groups': prompt_groups,\n", "# 'pool_methods': POOL_METHODS,\n", "# 'embeddings': {m: {i: all_embeddings[m][i] for i in range(n_prompts)} for m in POOL_METHODS},\n", "# 'sim_matrices': sim_matrices,\n", "# 'num_layers': n_layers,\n", "# 'hidden_dim': hidden_dim,\n", "# }\n", "# torch.save(export, 'qwen35_0.8b_embeddings.pt')\n", "# print('Saved to qwen35_0.8b_embeddings.pt')" ], "execution_count": null, "outputs": [] } ] }