aoiandroid commited on
Commit
5fd8b16
·
verified ·
1 Parent(s): 24a3e6f

Upload glm_ocr_coreml.ipynb with huggingface_hub

Browse files
Files changed (1) hide show
  1. glm_ocr_coreml.ipynb +566 -0
glm_ocr_coreml.ipynb ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# GLM-OCR to CoreML Conversion\n",
8
+ "\n",
9
+ "This notebook converts the [GLM-OCR](https://huggingface.co/aoiandroid/GLM-OCR) model (image-to-text OCR) to CoreML for use on iOS/macOS.\n",
10
+ "\n",
11
+ "**Model**: Multimodal OCR (CogViT visual encoder + cross-modal connector + GLM-0.5B decoder). \n",
12
+ "**Output**: Vision encoder as CoreML (`vision_encoder.mlpackage`), plus tokenizer/config for app-side use.\n",
13
+ "\n",
14
+ "**Requirements**: Python 3.10+, PyTorch, transformers (main branch for GLM-OCR support), coremltools. Colab or local GPU recommended."
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": null,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "# Install dependencies (uncomment in Colab or fresh env).\n",
24
+ "# For reproducible builds: pip install -r glm_ocr_coreml_requirements.txt\n",
25
+ "# Or with versions:\n",
26
+ "# !pip install -q torch==2.3.0 torchvision==0.18.0\n",
27
+ "# !pip install -q \"git+https://github.com/huggingface/transformers.git@main\"\n",
28
+ "# !pip install -q coremltools==7.2\n",
29
+ "# !pip install -q huggingface_hub>=0.23.0 pillow>=10.3.0"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import os\n",
39
+ "from pathlib import Path\n",
40
+ "\n",
41
+ "import numpy as np\n",
42
+ "import torch\n",
43
+ "import coremltools as ct\n",
44
+ "from PIL import Image"
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "markdown",
49
+ "metadata": {},
50
+ "source": [
51
+ "## 1. Load model and processor\n",
52
+ "\n",
53
+ "Using `aoiandroid/GLM-OCR` (duplicate of `zai-org/GLM-OCR`). Ensure transformers supports GLM-OCR (install from main: `pip install git+https://github.com/huggingface/transformers.git`)."
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": null,
59
+ "metadata": {},
60
+ "outputs": [],
61
+ "source": [
62
+ "MODEL_ID = \"aoiandroid/GLM-OCR\" # or \"zai-org/GLM-OCR\"\n",
63
+ "OUTPUT_DIR = Path(\"./glm_ocr_coreml\")\n",
64
+ "OUTPUT_DIR.mkdir(parents=True, exist_ok=True)\n",
65
+ "\n",
66
+ "# Load processor and model (use float32 for tracing; bfloat16 may not trace well)\n",
67
+ "from transformers import AutoProcessor, AutoModelForImageTextToText\n",
68
+ "\n",
69
+ "processor = AutoProcessor.from_pretrained(MODEL_ID)\n",
70
+ "model = AutoModelForImageTextToText.from_pretrained(\n",
71
+ " MODEL_ID,\n",
72
+ " torch_dtype=torch.float32,\n",
73
+ ")\n",
74
+ "model.eval()\n",
75
+ "\n",
76
+ "# Vision config for input shape (default image_size=336)\n",
77
+ "vision_config = getattr(model.config, \"vision_config\", None)\n",
78
+ "image_size = 336\n",
79
+ "if vision_config is not None:\n",
80
+ " image_size = getattr(vision_config, \"image_size\", 336)\n",
81
+ "if isinstance(image_size, (list, tuple)):\n",
82
+ " image_size = image_size[0]\n",
83
+ "hidden_size = getattr(model.config, \"hidden_size\", None) or (getattr(model.config.text_config, \"hidden_size\", 1024) if getattr(model.config, \"text_config\", None) else 1024)\n",
84
+ "print(f\"Image size: {image_size}, hidden_size: {hidden_size}\")"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "markdown",
89
+ "metadata": {},
90
+ "source": [
91
+ "### 1.1 Model structure validation\n",
92
+ "\n",
93
+ "Verify that the loaded model has the expected attributes (`model.model`, `get_image_features`). Check for a language/decoder submodule for decoder export."
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "metadata": {},
100
+ "outputs": [],
101
+ "source": [
102
+ "# Model structure validation (required for decoder export)\n",
103
+ "print(\"=== Model structure ===\")\n",
104
+ "print(f\"Model class: {type(model).__name__}\")\n",
105
+ "print(f\"Public attributes: {[a for a in dir(model) if not a.startswith('_')]}\")\n",
106
+ "\n",
107
+ "inner = getattr(model, \"model\", None)\n",
108
+ "if inner is None:\n",
109
+ " raise RuntimeError(\"model.model not found. Inspect the loaded model structure.\")\n",
110
+ "\n",
111
+ "if not hasattr(inner, \"get_image_features\"):\n",
112
+ " raise RuntimeError(\n",
113
+ " \"get_image_features not found. Install transformers from main: \"\n",
114
+ " \"pip install git+https://github.com/huggingface/transformers.git\"\n",
115
+ " )\n",
116
+ "\n",
117
+ "print(f\"vision_config: {getattr(model.config, 'vision_config', 'N/A')}\")\n",
118
+ "print(f\"hidden_size: {getattr(model.config, 'hidden_size', 'N/A')}\")\n",
119
+ "\n",
120
+ "# For decoder: look for language/text/decoder submodule on model or model.model\n",
121
+ "decoder_candidates = [\"language_model\", \"text_model\", \"decoder\", \"model\"]\n",
122
+ "for name in decoder_candidates:\n",
123
+ " obj = getattr(model, name, None) or getattr(inner, name, None)\n",
124
+ " if obj is not None and hasattr(obj, \"forward\"):\n",
125
+ " print(f\"Decoder candidate: {name} (on model or model.model)\")\n",
126
+ "print(\"Structure validation OK\")"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "markdown",
131
+ "metadata": {},
132
+ "source": [
133
+ "## 2. Export vision encoder to CoreML\n",
134
+ "\n",
135
+ "The vision part of GLM-OCR turns `pixel_values` into hidden states consumed by the language model. We trace `get_image_features(pixel_values)` to obtain a CoreML vision encoder. The app can then run this and feed the outputs into a separate decoder or use the rest of the pipeline in Swift."
136
+ ]
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "execution_count": null,
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "# Wrapper: pixel_values -> last_hidden_state\n",
145
+ "# GlmOcrForConditionalGeneration has .model (GlmOcrModel) with get_image_features\n",
146
+ "class VisionEncoderWrapper(torch.nn.Module):\n",
147
+ " def __init__(self, parent_model):\n",
148
+ " super().__init__()\n",
149
+ " self.base = getattr(parent_model, \"model\", parent_model)\n",
150
+ " if not hasattr(self.base, \"get_image_features\"):\n",
151
+ " raise AttributeError(\"Loaded model has no get_image_features; ensure transformers supports GLM-OCR.\")\n",
152
+ "\n",
153
+ " def forward(self, pixel_values: torch.Tensor):\n",
154
+ " out = self.base.get_image_features(pixel_values=pixel_values)\n",
155
+ " return out.last_hidden_state\n",
156
+ "\n",
157
+ "wrapper = VisionEncoderWrapper(model)\n",
158
+ "wrapper.eval()\n",
159
+ "\n",
160
+ "batch, channels = 1, 3\n",
161
+ "dummy_pixel = torch.randn(batch, channels, image_size, image_size, dtype=torch.float32)\n",
162
+ "\n",
163
+ "with torch.no_grad():\n",
164
+ " traced = torch.jit.trace(\n",
165
+ " wrapper,\n",
166
+ " (dummy_pixel,),\n",
167
+ " check_trace=False,\n",
168
+ " strict=False,\n",
169
+ " )\n",
170
+ "# Check output shape\n",
171
+ "with torch.no_grad():\n",
172
+ " out = traced(dummy_pixel)\n",
173
+ "print(f\"Vision encoder output shape: {out.shape}\")"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "metadata": {},
180
+ "outputs": [],
181
+ "source": [
182
+ "# Convert vision encoder to CoreML\n",
183
+ "# Output shape (1, vision_seq_len, hidden_size) - use actual shape from trace\n",
184
+ "vision_seq_len = out.shape[1]\n",
185
+ "hidden_size = out.shape[2]\n",
186
+ "\n",
187
+ "input_types = [\n",
188
+ " ct.TensorType(\n",
189
+ " name=\"pixel_values\",\n",
190
+ " shape=(1, channels, image_size, image_size),\n",
191
+ " dtype=np.float32,\n",
192
+ " )\n",
193
+ "]\n",
194
+ "output_types = [ct.TensorType(name=\"vision_hidden_states\")]\n",
195
+ "\n",
196
+ "# Use iOS16 for reliability; set to iOS15 or iOS17 per target device if needed\n",
197
+ "vision_mlmodel = ct.convert(\n",
198
+ " traced,\n",
199
+ " inputs=input_types,\n",
200
+ " outputs=output_types,\n",
201
+ " convert_to=\"mlprogram\",\n",
202
+ " minimum_deployment_target=ct.target.iOS16,\n",
203
+ " compute_units=ct.ComputeUnit.ALL,\n",
204
+ ")\n",
205
+ "\n",
206
+ "vision_path = OUTPUT_DIR / \"vision_encoder.mlpackage\"\n",
207
+ "vision_mlmodel.save(str(vision_path))\n",
208
+ "print(f\"Saved vision encoder to {vision_path}\")"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "metadata": {},
215
+ "outputs": [],
216
+ "source": [
217
+ "# Save vision encoder spec for Swift (vision_seq_len, hidden_size, image_size)\n",
218
+ "import json\n",
219
+ "\n",
220
+ "model_spec = {\n",
221
+ " \"vision_encoder\": {\n",
222
+ " \"input\": {\n",
223
+ " \"name\": \"pixel_values\",\n",
224
+ " \"shape\": [1, 3, int(image_size), int(image_size)],\n",
225
+ " \"dtype\": \"float32\",\n",
226
+ " },\n",
227
+ " \"output\": {\n",
228
+ " \"name\": \"vision_hidden_states\",\n",
229
+ " \"shape\": [1, int(vision_seq_len), int(hidden_size)],\n",
230
+ " \"dtype\": \"float32\",\n",
231
+ " },\n",
232
+ " },\n",
233
+ " \"image_size\": int(image_size),\n",
234
+ " \"vision_seq_len\": int(vision_seq_len),\n",
235
+ " \"hidden_size\": int(hidden_size),\n",
236
+ " \"model_id\": MODEL_ID,\n",
237
+ "}\n",
238
+ "\n",
239
+ "spec_path = OUTPUT_DIR / \"model_spec.json\"\n",
240
+ "with open(spec_path, \"w\") as f:\n",
241
+ " json.dump(model_spec, f, indent=2)\n",
242
+ "print(f\"Model spec saved: {spec_path}\")"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "markdown",
247
+ "metadata": {},
248
+ "source": [
249
+ "## 3. Save processor and config\n",
250
+ "\n",
251
+ "Copy tokenizer and config so the app can run preprocessing and decoding. Full autoregressive decoding (image + prompt -> text) would require either exporting the decoder as a second CoreML model or implementing the generation loop in Swift using the vision encoder output."
252
+ ]
253
+ },
254
+ {
255
+ "cell_type": "code",
256
+ "execution_count": null,
257
+ "metadata": {},
258
+ "outputs": [],
259
+ "source": [
260
+ "# Save processor (tokenizer + image processor) and config to output dir\n",
261
+ "processor.save_pretrained(OUTPUT_DIR)\n",
262
+ "model.config.save_pretrained(OUTPUT_DIR)\n",
263
+ "print(f\"Saved processor and config to {OUTPUT_DIR}\")\n",
264
+ "print(\"Contents:\", list(OUTPUT_DIR.iterdir()))"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "markdown",
269
+ "metadata": {},
270
+ "source": [
271
+ "## 4. Verify CoreML I/O (optional)\n",
272
+ "\n",
273
+ "Inspect input/output names and shapes for integration in Swift."
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": [
282
+ "loaded = ct.models.MLModel(str(vision_path))\n",
283
+ "spec = loaded.get_spec()\n",
284
+ "print(\"Vision encoder inputs:\", [d.name for d in spec.description.input])\n",
285
+ "print(\"Vision encoder outputs:\", [d.name for d in spec.description.output])"
286
+ ]
287
+ },
288
+ {
289
+ "cell_type": "markdown",
290
+ "metadata": {},
291
+ "source": [
292
+ "## 5. Optional: Decoder or full-model export\n",
293
+ "\n",
294
+ "The full GLM-OCR pipeline (image + prompt -> generated text) uses `model.generate()` with cache and variable sequence length, which is hard to export as a single CoreML model. Options:\n",
295
+ "\n",
296
+ "- **Vision encoder only** (done above): Use `vision_encoder.mlpackage` in the app and implement the decoder/generation loop in Swift, or call a separate decoder CoreML if you export it.\n",
297
+ "- **Decoder export**: Trace the text model with fixed `encoder_hidden_states` (from the vision encoder output) and `input_ids` to get logits; then run autoregressive generation in the app. This requires building a wrapper that takes (input_ids, encoder_hidden_states, attention_mask) and returns logits, similar to T5/encoder-decoder conversion scripts.\n",
298
+ "- **Quantization**: Use `coremltools.optimize.coreml.palettize_weights` or `linear_quantize_weights` to reduce vision encoder size (e.g. INT8 or 4-bit)."
299
+ ]
300
+ },
301
+ {
302
+ "cell_type": "markdown",
303
+ "metadata": {},
304
+ "source": [
305
+ "### 2.1 Quantization (FP16 / INT8) and size comparison\n",
306
+ "\n",
307
+ "Apply FP16 and INT8 quantization to reduce vision encoder size for iOS. **After INT8 quantization, run the accuracy verification cell (Section 6) below.**"
308
+ ]
309
+ },
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": null,
313
+ "metadata": {},
314
+ "outputs": [],
315
+ "source": [
316
+ "import shutil\n",
317
+ "from coremltools.optimize.coreml import (\n",
318
+ " linear_quantize_weights,\n",
319
+ " OptimizationConfig,\n",
320
+ " OpLinearQuantizerConfig,\n",
321
+ ")\n",
322
+ "\n",
323
+ "# FP16 (minimal accuracy loss)\n",
324
+ "vision_fp16 = ct.models.MLModel(str(vision_path))\n",
325
+ "vision_fp16_path = OUTPUT_DIR / \"vision_encoder_fp16.mlpackage\"\n",
326
+ "try:\n",
327
+ " q16 = ct.models.neural_network.quantization_utils.quantize_weights(vision_fp16, nbits=16)\n",
328
+ " q16.save(str(vision_fp16_path))\n",
329
+ "except Exception as e:\n",
330
+ " print(f\"FP16 quantization failed: {e}\")\n",
331
+ " vision_fp16_path = None\n",
332
+ "\n",
333
+ "# INT8 (smaller; run accuracy verification after)\n",
334
+ "config = OptimizationConfig(\n",
335
+ " global_config=OpLinearQuantizerConfig(mode=\"linear_symmetric\", weight_threshold=512)\n",
336
+ ")\n",
337
+ "vision_int8 = linear_quantize_weights(vision_mlmodel, config)\n",
338
+ "vision_int8_path = OUTPUT_DIR / \"vision_encoder_int8.mlpackage\"\n",
339
+ "vision_int8.save(str(vision_int8_path))\n",
340
+ "\n",
341
+ "# Size comparison (MB)\n",
342
+ "for label, path in [\n",
343
+ " (\"FP32 (original)\", vision_path),\n",
344
+ " (\"FP16\", vision_fp16_path),\n",
345
+ " (\"INT8\", vision_int8_path),\n",
346
+ "]:\n",
347
+ " if path is not None and path.exists():\n",
348
+ " size_mb = sum(f.stat().st_size for f in path.rglob(\"*\") if f.is_file()) / 1e6\n",
349
+ " print(f\"{label}: {size_mb:.1f} MB\")"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "markdown",
354
+ "metadata": {},
355
+ "source": [
356
+ "## 6. Accuracy verification (PyTorch vs CoreML)\n",
357
+ "\n",
358
+ "Compare vision encoder outputs: PyTorch traced model vs CoreML. Use a test image (or a dummy image if `test_image.png` is missing). Cosine similarity per token should be close to 1.0."
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "code",
363
+ "execution_count": null,
364
+ "metadata": {},
365
+ "outputs": [],
366
+ "source": [
367
+ "from numpy.linalg import norm\n",
368
+ "\n",
369
+ "# Test image: use test_image.png if present, else dummy (shape-only check)\n",
370
+ "test_image_path = Path(\"test_image.png\")\n",
371
+ "if test_image_path.exists():\n",
372
+ " test_image = Image.open(test_image_path).convert(\"RGB\")\n",
373
+ " inputs = processor(images=test_image, return_tensors=\"pt\")\n",
374
+ " pixel_values = inputs[\"pixel_values\"].to(torch.float32)\n",
375
+ " if pixel_values.shape[2] != image_size or pixel_values.shape[3] != image_size:\n",
376
+ " pixel_values = torch.nn.functional.interpolate(\n",
377
+ " pixel_values, size=(image_size, image_size), mode=\"bilinear\"\n",
378
+ " )\n",
379
+ "else:\n",
380
+ " pixel_values = torch.randn(1, 3, image_size, image_size, dtype=torch.float32)\n",
381
+ " print(\"No test_image.png; using dummy tensor for shape verification.\")\n",
382
+ "\n",
383
+ "# PyTorch output\n",
384
+ "with torch.no_grad():\n",
385
+ " pt_out = traced(pixel_values).numpy()\n",
386
+ "\n",
387
+ "# CoreML output (FP32 model)\n",
388
+ "pv_np = pixel_values.cpu().numpy() if pixel_values.is_cuda else pixel_values.numpy()\n",
389
+ "coreml_out = vision_mlmodel.predict({\"pixel_values\": pv_np})[\"vision_hidden_states\"]\n",
390
+ "\n",
391
+ "# Cosine similarity per token (average and min)\n",
392
+ "cos_sims = []\n",
393
+ "for i in range(pt_out.shape[1]):\n",
394
+ " a, b = pt_out[0, i], coreml_out[0, i]\n",
395
+ " n = norm(a) * norm(b)\n",
396
+ " cos_sims.append(np.dot(a, b) / n if n > 0 else 1.0)\n",
397
+ "print(f\"Cosine similarity (PyTorch vs CoreML FP32) mean: {np.mean(cos_sims):.6f}, min: {np.min(cos_sims):.6f}\")\n",
398
+ "assert np.mean(cos_sims) > 0.999, \"Accuracy drop too large; check conversion settings.\"\n",
399
+ "print(\"Accuracy verification OK\")"
400
+ ]
401
+ },
402
+ {
403
+ "cell_type": "markdown",
404
+ "metadata": {},
405
+ "source": [
406
+ "## 7. Decoder export (single-step, optional)\n",
407
+ "\n",
408
+ "Export a one-step decoder: `(input_ids, encoder_hidden_states, attention_mask) -> logits`, so the app can run an autoregressive loop in Swift. **GLM-OCR may not expose a separate decoder API** (it merges vision and text in one forward). If trace fails, only the vision encoder is used; implement generation in Swift or call the full model in Python."
409
+ ]
410
+ },
411
+ {
412
+ "cell_type": "code",
413
+ "execution_count": null,
414
+ "metadata": {},
415
+ "outputs": [],
416
+ "source": [
417
+ "# Decoder export: try single-step (input_ids, encoder_hidden_states, attention_mask) -> logits\n",
418
+ "# GLM-OCR may merge vision+text in one forward; we try building inputs_embeds from vision + text embeddings.\n",
419
+ "decoder_exported = False\n",
420
+ "DECODER_MAX_LEN = max(256, int(vision_seq_len) + 64) # ensure text segment exists for trace\n",
421
+ "\n",
422
+ "try:\n",
423
+ " inner = model.model\n",
424
+ " embed_fn = getattr(model, \"get_input_embeddings\", None) or getattr(inner, \"get_input_embeddings\", None)\n",
425
+ " if embed_fn is None:\n",
426
+ " raise AttributeError(\"No get_input_embeddings on model\")\n",
427
+ "\n",
428
+ " class DecoderStepWrapper(torch.nn.Module):\n",
429
+ " def __init__(self, parent_model):\n",
430
+ " super().__init__()\n",
431
+ " self.inner = parent_model.model\n",
432
+ " self.lm_head = parent_model.lm_head\n",
433
+ " self.embed = parent_model.get_input_embeddings()\n",
434
+ "\n",
435
+ " def forward(\n",
436
+ " self,\n",
437
+ " input_ids: torch.Tensor,\n",
438
+ " encoder_hidden_states: torch.Tensor,\n",
439
+ " attention_mask: torch.Tensor,\n",
440
+ " ):\n",
441
+ " # Assume sequence layout: [image tokens (vision_seq_len), text tokens (rest)]\n",
442
+ " seq_len = input_ids.shape[1]\n",
443
+ " if encoder_hidden_states.shape[1] != vision_seq_len:\n",
444
+ " raise ValueError(\"encoder_hidden_states seq len must match vision_seq_len\")\n",
445
+ " text_len = seq_len - vision_seq_len\n",
446
+ " if text_len <= 0:\n",
447
+ " text_emb = self.embed(input_ids)\n",
448
+ " inputs_embeds = encoder_hidden_states\n",
449
+ " else:\n",
450
+ " text_emb = self.embed(input_ids[:, vision_seq_len:])\n",
451
+ " inputs_embeds = torch.cat([encoder_hidden_states, text_emb], dim=1)\n",
452
+ " out = self.inner(\n",
453
+ " attention_mask=attention_mask,\n",
454
+ " inputs_embeds=inputs_embeds,\n",
455
+ " use_cache=False,\n",
456
+ " )\n",
457
+ " return self.lm_head(out.last_hidden_state)\n",
458
+ "\n",
459
+ " dec_wrapper = DecoderStepWrapper(model)\n",
460
+ " dec_wrapper.eval()\n",
461
+ " dummy_ids = torch.randint(0, 1000, (1, DECODER_MAX_LEN), dtype=torch.long)\n",
462
+ " dummy_enc = torch.randn(1, vision_seq_len, hidden_size, dtype=torch.float32)\n",
463
+ " dummy_attn = torch.ones(1, DECODER_MAX_LEN, dtype=torch.long)\n",
464
+ " with torch.no_grad():\n",
465
+ " dec_traced = torch.jit.trace(\n",
466
+ " dec_wrapper,\n",
467
+ " (dummy_ids, dummy_enc, dummy_attn),\n",
468
+ " check_trace=False,\n",
469
+ " strict=False,\n",
470
+ " )\n",
471
+ " print(\"Decoder trace OK; converting to CoreML...\")\n",
472
+ "except Exception as e:\n",
473
+ " print(f\"Decoder export skipped: {e}\")\n",
474
+ " print(\"Use vision encoder only; implement autoregressive decoding in Swift or run full model in Python.\")\n",
475
+ " dec_traced = None"
476
+ ]
477
+ },
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": null,
481
+ "metadata": {},
482
+ "outputs": [],
483
+ "source": [
484
+ "if dec_traced is not None:\n",
485
+ " dec_input_types = [\n",
486
+ " ct.TensorType(name=\"input_ids\", shape=(1, DECODER_MAX_LEN), dtype=np.int32),\n",
487
+ " ct.TensorType(name=\"encoder_hidden_states\", shape=(1, vision_seq_len, hidden_size), dtype=np.float32),\n",
488
+ " ct.TensorType(name=\"attention_mask\", shape=(1, DECODER_MAX_LEN), dtype=np.int32),\n",
489
+ " ]\n",
490
+ " dec_output_types = [ct.TensorType(name=\"logits\")]\n",
491
+ " decoder_mlmodel = ct.convert(\n",
492
+ " dec_traced,\n",
493
+ " inputs=dec_input_types,\n",
494
+ " outputs=dec_output_types,\n",
495
+ " convert_to=\"mlprogram\",\n",
496
+ " minimum_deployment_target=ct.target.iOS16,\n",
497
+ " compute_units=ct.ComputeUnit.ALL,\n",
498
+ " )\n",
499
+ " decoder_path = OUTPUT_DIR / \"decoder.mlpackage\"\n",
500
+ " decoder_mlmodel.save(str(decoder_path))\n",
501
+ " print(f\"Saved decoder to {decoder_path}\")\n",
502
+ " decoder_exported = True\n",
503
+ " # Update model_spec with decoder I/O\n",
504
+ " model_spec[\"decoder\"] = {\n",
505
+ " \"input\": {\"names\": [\"input_ids\", \"encoder_hidden_states\", \"attention_mask\"], \"shapes\": [(1, DECODER_MAX_LEN), (1, vision_seq_len, hidden_size), (1, DECODER_MAX_LEN)]},\n",
506
+ " \"output\": {\"name\": \"logits\", \"shape\": [1, DECODER_MAX_LEN, int(getattr(model.config, \"vocab_size\", getattr(model.config.text_config, \"vocab_size\", 59392)))]},\n",
507
+ " }\n",
508
+ " with open(spec_path, \"w\") as f:\n",
509
+ " json.dump(model_spec, f, indent=2)\n",
510
+ "else:\n",
511
+ " print(\"Decoder not exported; model_spec unchanged.\")"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "markdown",
516
+ "metadata": {},
517
+ "source": [
518
+ "## 8. Swift integration sketch\n",
519
+ "\n",
520
+ "Use the vision encoder (and optional decoder) in an iOS app as below. Add `vision_encoder.mlpackage` to the Xcode project; if the decoder was exported, add `decoder.mlpackage` and run an autoregressive loop."
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "code",
525
+ "execution_count": null,
526
+ "metadata": {},
527
+ "outputs": [],
528
+ "source": [
529
+ "swift_example = \"\"\"\n",
530
+ "// Swift: CoreML vision encoder + optional decoder loop\n",
531
+ "// 1. Add vision_encoder.mlpackage (and decoder.mlpackage if exported) to the Xcode project.\n",
532
+ "// 2. Preprocess image to 336x336 float32 and run vision encoder.\n",
533
+ "\n",
534
+ "import CoreML\n",
535
+ "import Vision\n",
536
+ "\n",
537
+ "let visionModel = try VisionEncoder(configuration: MLModelConfiguration())\n",
538
+ "let pixelValues = preprocessImage(uiImage) // shape (1, 3, 336, 336), Float32\n",
539
+ "\n",
540
+ "let input = VisionEncoderInput(pixel_values: pixelValues)\n",
541
+ "let output = try visionModel.prediction(input: input)\n",
542
+ "let hiddenStates = output.vision_hidden_states // (1, vision_seq_len, hidden_size)\n",
543
+ "\n",
544
+ "// Pass hiddenStates to the decoder for text generation:\n",
545
+ "// - If decoder.mlpackage was exported: load DecoderStep, then in a loop feed\n",
546
+ "// (input_ids, encoder_hidden_states, attention_mask) and take argmax(logits) for next token.\n",
547
+ "// - Otherwise implement the generation loop in Swift or call the full model elsewhere.\n",
548
+ "\"\"\"\n",
549
+ "print(swift_example)"
550
+ ]
551
+ }
552
+ ],
553
+ "metadata": {
554
+ "kernelspec": {
555
+ "display_name": "Python 3",
556
+ "language": "python",
557
+ "name": "python3"
558
+ },
559
+ "language_info": {
560
+ "name": "python",
561
+ "version": "3.10.0"
562
+ }
563
+ },
564
+ "nbformat": 4,
565
+ "nbformat_minor": 4
566
+ }