Kuangdai committed
Commit 6fb6c07 · 0 parent(s)
Initial release of SoilFormer
Browse files
- .gitattributes +9 -0
- .gitignore +181 -0
- LICENSE +21 -0
- README.md +168 -0
- config/column_rules_numeric.json +30 -0
- config/config_data.json +15 -0
- config/config_model.json +25 -0
- config/config_train.json +51 -0
- data/cat_vocab.json +3 -0
- data/numeric_vocab.json +3 -0
- data/photo_map.json +3 -0
- data/tabular_data.csv +3 -0
- data/tabular_meta.json +3 -0
- data/tabular_meta_numeric_stats.csv +3 -0
- example/input_card.json +3 -0
- example/input_card__masked.json +3 -0
- example/input_card__unmasked.json +3 -0
- example/output_card.json +3 -0
- example/output_card__acc.json +3 -0
- inference_create_input_card.py +318 -0
- inference_predict_output_card.py +545 -0
- model_weights/gemma3n_E2B_vision_only/config.json +3 -0
- model_weights/gemma3n_E2B_vision_only/model.safetensors +3 -0
- model_weights/gemma3n_E2B_vision_only/modeling_gemma3n.py +3 -0
- model_weights/gemma3n_E2B_vision_only/processor_config.json +3 -0
- model_weights/gemma3n_E2B_vision_only/tokenizer.json +3 -0
- model_weights/gemma3n_E2B_vision_only/tokenizer_config.json +3 -0
- model_weights/gemma3n_E2B_vision_only/vision_extractor_config.json +3 -0
- model_weights/soilformer_pretrain/hetero_epoch_200.pt +3 -0
- modelling/__init__.py +0 -0
- modelling/decode_categorical.py +423 -0
- modelling/decode_numeric.py +238 -0
- modelling/embed_categorical.py +322 -0
- modelling/embed_numeric.py +547 -0
- modelling/embed_vision_gemma3n.py +552 -0
- modelling/layer.py +353 -0
- modelling/loader.py +1025 -0
- modelling/soilformer.py +696 -0
- modelling/train.py +552 -0
- modelling/utils.py +132 -0
- requirements.txt +10 -0
- resources/arch.png +3 -0
.gitattributes
ADDED
@@ -0,0 +1,9 @@
+# Auto detect text files and perform LF normalization
+* text=auto
+model_weights/** filter=lfs diff=lfs merge=lfs -text
+data/*.csv filter=lfs diff=lfs merge=lfs -text
+data/*.json filter=lfs diff=lfs merge=lfs -text
+example/*.json filter=lfs diff=lfs merge=lfs -text
+resources/*.png filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
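Because of the filters above, the heavyweight assets in this commit (model weights, data tables, example cards) are checked in as small Git LFS pointer stubs rather than real payloads; the `data/` and `example/` entries later in this diff show the stub format. A minimal sketch for spotting such stubs in a local checkout (the helper name is ours, not the repository's):

```python
from pathlib import Path


def is_lfs_pointer(path: str) -> bool:
    """Return True if `path` holds a Git LFS pointer stub, not real data."""
    # Pointer stubs are tiny text files whose first line names the LFS spec,
    # e.g. "version https://git-lfs.github.com/spec/v1" (see data/*.json below).
    try:
        first_line = Path(path).read_text(errors="ignore").splitlines()[:1]
    except OSError:
        return False
    return bool(first_line) and first_line[0].startswith(
        "version https://git-lfs.github.com/spec/"
    )

# In a fresh clone without `git lfs pull`, is_lfs_pointer("data/cat_vocab.json")
# would return True.
```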
.gitignore
ADDED
@@ -0,0 +1,181 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# UV
+# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+#uv.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+# Ruff stuff:
+.ruff_cache/
+
+# PyPI configuration file
+.pypirc
+
+# Cursor
+# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
+# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
+# refer to https://docs.cursor.com/context/ignore-files
+.cursorignore
+.cursorindexingignore
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Kuangdai
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
ADDED
@@ -0,0 +1,168 @@
+---
+license: mit
+library_name: pytorch
+language:
+- en
+tags:
+- soil
+- soil-science
+- earth-science
+- environmental-science
+- multimodal
+- tabular
+- transformer
+- representation-learning
+- masked-feature-modeling
+- remote-sensing
+- europe
+datasets:
+- earthroverprogram/lucas-mega
+---
+
+# SoilFormer
+
+A multimodal tabular transformer trained on [LUCAS-MEGA](https://huggingface.co/datasets/earthroverprogram/lucas-mega).
+
+[Manuscript](https://huggingface.co/datasets/earthroverprogram/lucas-mega/manuscript.pdf)
+
+## Introduction
+
+SoilFormer is a multimodal transformer for representation learning in soil–environment systems. It is trained on
+LUCAS-MEGA, a large-scale dataset built from European soil and environmental observations, with the LUCAS soil survey as
+its backbone. LUCAS-MEGA integrates heterogeneous sources into a machine-learning-ready sample–feature table, covering
+numerical, categorical, textual, and visual modalities across soil physical, chemical, hydrological, environmental, and
+site-related properties.
+
+SoilFormer learns from partially observed multimodal samples using masked feature modeling. During training, a subset of
+observed categorical and numerical features is masked, and the model reconstructs them from the remaining tabular and
+visual context. The architecture combines grouped categorical embedding, grouped numerical encoding/decoding, vision
+feature extraction and compression, transformer layers, and heteroscedastic prediction heads for uncertainty-aware
+reconstruction.
+
+<img src="resources/arch.png" alt="SoilFormer architecture" width="70%">
+
+## Training
+
+Train SoilFormer with:
+
+```bash
+python modelling/train.py
+```
+
+Main configuration files:
+
+* `config/config_model.json`: model architecture parameters, including embedding sizes, transformer layer settings,
+  decoder settings, dtype, and vision model configuration.
+* `config/config_data.json`: data parameters, including CSV path, vocab paths, numeric statistics, photo mapping, image
+  root, train/eval split, batch size, and masking ratios.
+* `config/config_train.json`: training hyperparameters, including runtime device, seed, optimizer settings, scheduler
+  settings, checkpoint behavior, loss options, logging, and output paths.
+
+## Inference
+
+Inference uses readable JSON input cards. The workflow is:
+
+1. Create input cards from one dataset row.
+2. Edit the masked card manually if desired.
+3. Run model prediction from the edited card.
+4. Optionally compare predictions against the unmasked answer card.
+
+### 1. Create input cards
+
+```bash
+python inference_create_input_card.py \
+    --row_index 10 \
+    --output example/input_card.json
+```
+
+This writes two files:
+
+```text
+example/input_card__unmasked.json
+example/input_card__masked.json
+```
+
+The unmasked card contains the raw readable values from the CSV row. The masked card randomly replaces a fraction of
+categorical and numeric values with `null`. Natural missing values remain as empty strings `""`, while active masks are
+represented as `null`.
+
+Default masking ratios are 0.15 for both categorical and numeric features:
+
+```bash
+python inference_create_input_card.py \
+    --row_index 10 \
+    --output example/input_card.json \
+    --cat_mask_ratio 0.15 \
+    --num_mask_ratio 0.15 \
+    --seed 42
+```
+
+The card format is intentionally simple and user-editable. Users can copy this card as a template, replace the values
+with their own soil sample information, and set fields to `null` to indicate which ones should be predicted during
+inference:
+
+```json
+{
+  "categorical": {
+    "land_site:land_cover_primary": "B16: Cropland => Cereals => Maize",
+    "land_site:land_use_primary": null,
+    "soil_type:WRB_soil_group": "Cambisol",
+    "texture:ISSS_class": "silty clay",
+    "...": "..."
+  },
+  "numeric": {
+    "carbon:CaCO3_content (g/kg)": 7.0,
+    "carbon:SOC_saturation_ratio": 0.3647958934307098,
+    "geographic:latitude (deg)": 38.8513900000485,
+    "geographic:longitude (deg)": -9.29050000007487,
+    "mass_density:bulk_density (g/cm³)": null,
+    "...": "..."
+  },
+  "vision": {
+    "image_path_suffix": "relative/path/to/photo.jpg"
+  }
+}
+```
+
+### 2. Run prediction
+
+```bash
+python inference_predict_output_card.py \
+    --checkpoint model_weights/soilformer_pretrain/hetero_epoch_200.pt \
+    --input_card example/input_card__masked.json \
+    --output example/output_card.json
+```
+
+This writes:
+
+```text
+example/output_card.json
+```
+
+`output_card.json` contains readable predictions:
+
+* categorical outputs are decoded back to raw category labels;
+* numeric outputs are converted from z-score space back to the original physical units;
+* vision input is read from `vision.image_path_suffix` together with `photo_root` in `config/config_data.json`.
+
+### 3. Evaluation with an answer card
+
+```bash
+python inference_predict_output_card.py \
+    --checkpoint model_weights/soilformer_pretrain/hetero_epoch_200.pt \
+    --input_card example/input_card__masked.json \
+    --answer_card example/input_card__unmasked.json \
+    --output example/output_card.json
+```
+
+This additionally writes:
+
+```text
+example/output_card__acc.json
+```
+
+When `--answer_card` is provided, `output_card__acc.json` reports reconstruction metrics over fields that are `null` in
+the masked input card:
+
+* categorical accuracy for masked categorical fields;
+* numeric MAE for masked numeric fields, measured in the original feature units.
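The README's mention of heteroscedastic prediction heads pairs naturally with the `cat_s_bound`/`num_s_bound` options in `config/config_train.json` below. Here is a minimal sketch of an uncertainty-aware Gaussian-NLL regression loss of that flavor, assuming `s` is a predicted log-variance clamped to a bound; this is an illustration, not the code in `modelling/train.py`:

```python
import torch


def hetero_gaussian_nll(mean: torch.Tensor, s: torch.Tensor,
                        target: torch.Tensor, s_bound: float = 4.0) -> torch.Tensor:
    """Heteroscedastic regression loss: the head predicts a mean plus a
    log-variance term s; clamping s stabilizes early training (assumed to
    mirror the role of `num_s_bound`)."""
    s = s.clamp(-s_bound, s_bound)
    return (0.5 * torch.exp(-s) * (target - mean) ** 2 + 0.5 * s).mean()


# Toy usage on a batch of 8 samples with 5 masked numeric targets each:
mean, s, target = torch.randn(8, 5), torch.randn(8, 5), torch.randn(8, 5)
loss = hetero_gaussian_nll(mean, s, target)
```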
config/column_rules_numeric.json
ADDED
@@ -0,0 +1,30 @@
+{
+  "texture:silt_percentage (%)": ">=0",
+  "chemical:pH_in_H2O": ">0",
+  "chemical:pH_in_CaCl2": ">0",
+  "carbon:organic_carbon_content (g/kg)": ">0",
+  "carbon:CaCO3_content (g/kg)": ">0",
+  "carbon:observed_vs_typical_soc_index_confidence_zone": "exclude",
+  "carbon:observed_vs_typical_soc_index": "exclude",
+  "fertility:N_extractable (g/kg)": ">0",
+  "fertility:K_extractable (mg/kg)": ">0",
+  "fertility:P_extractable (mg/kg)": ">0",
+  "fertility:P_available_stock (kg ha⁻¹)": ">0",
+  "land_degradation:soil_erosion_exceeding_10Mg_ha_yr (t ha⁻¹ yr⁻¹)": "exclude",
+  "crop_plant:cover_crop_fraction_5th_percentile (‱)": "exclude",
+  "crop_plant:cover_crop_fraction_95th_percentile (‱)": "exclude",
+  "mass_density:bulk_density_0_10cm (g/cm³)": ">0",
+  "mass_density:bulk_density_10_20cm (g/cm³)": ">0",
+  "mass_density:bulk_density (g/cm³)": ">0",
+  "biodiversity:land_use_change_pressure_index": "exclude",
+  "biodiversity:genetically_modified_organism_use_pressure_index": "exclude",
+  "trace_elements:Zn_concentration_5th_percentile (mg/kg)": "exclude",
+  "trace_elements:Zn_concentration_95th_percentile (mg/kg)": "exclude",
+  "trace_elements:As_concentration_std (log10 mg/kg)": "exclude",
+  "trace_elements:As_concentration_skewness": "exclude",
+  "trace_elements:As_concentration_kurtosis": "exclude",
+  "trace_elements:Hg_residual (µg/kg)": "exclude",
+  "climate:monthly_temperature_JAN_to_DEC (°C)": ">-100",
+  "climate:monthly_precipitation_JAN_to_DEC (mm)": ">-100",
+  "topography_geology:elevation (m)": "<4000"
+}
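These rule strings constrain or drop numeric columns. The repository's loader is the actual consumer, but a plausible reading is a comparison operator plus a threshold, with `"exclude"` removing the column entirely. A hedged sketch of that interpretation:

```python
import operator
import re

_OPS = {">=": operator.ge, "<=": operator.le, ">": operator.gt, "<": operator.lt}


def rule_accepts(rule: str, value: float) -> bool:
    """One possible semantics for column_rules_numeric.json entries."""
    if rule == "exclude":
        return False  # column is dropped regardless of value
    m = re.fullmatch(r"(>=|<=|>|<)\s*(-?\d+(?:\.\d+)?)", rule)
    if m is None:
        raise ValueError(f"Unrecognized rule: {rule!r}")
    op, threshold = _OPS[m.group(1)], float(m.group(2))
    return op(value, threshold)


assert rule_accepts(">0", 7.0)          # e.g. CaCO3 content must be positive
assert not rule_accepts("<4000", 5123)  # elevation above the cap is rejected
```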
config/config_data.json
ADDED
@@ -0,0 +1,15 @@
+{
+  "data_csv_path": "data/tabular_data.csv",
+  "photo_map_path": "data/photo_map.json",
+  "cat_vocab_path": "data/cat_vocab.json",
+  "numeric_vocab_path": "data/numeric_vocab.json",
+  "numeric_stats_path": "data/tabular_meta_numeric_stats.csv",
+  "photo_root": "",
+  "image_size": 512,
+  "train_ratio": 0.8,
+  "train_eval_split_seed": 42,
+  "batch_size": 64,
+  "cat_mask_ratio": 0.15,
+  "num_mask_ratio": 0.15,
+  "active_mask_seed": 42
+}
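For orientation, here is one way the `train_ratio`/`train_eval_split_seed` fields could drive a deterministic train/eval split; the authoritative logic lives in `modelling/loader.py`, so treat this as a sketch:

```python
import json
import random

with open("config/config_data.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)

num_rows = 1000  # stand-in for the row count of data/tabular_data.csv
indices = list(range(num_rows))
random.Random(cfg["train_eval_split_seed"]).shuffle(indices)  # seed 42 above
cut = int(num_rows * cfg["train_ratio"])                      # ratio 0.8 above
train_idx, eval_idx = indices[:cut], indices[cut:]
```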
config/config_model.json
ADDED
@@ -0,0 +1,25 @@
+{
+  "dtype": "float32",
+  "tabular_meta": "data/tabular_meta.json",
+  "vision_model_dir": "./model_weights/gemma3n_E2B_vision_only",
+  "vision_num_output_tokens_reduced": 32,
+  "vision_num_heads_for_token_reduction": 4,
+  "vision_reducer_bottleneck_dim": 768,
+  "vision_reducer_project_back": false,
+  "cat_vocab_json": "data/cat_vocab.json",
+  "cat_hidden_size": 768,
+  "cat_decode_middle_size": null,
+  "numeric_vocab_json": "data/numeric_vocab.json",
+  "numeric_hidden_size": 768,
+  "numeric_encode_middle_size": null,
+  "numeric_decode_middle_size": null,
+  "layer_num_query_heads": 8,
+  "layer_num_kv_heads": 2,
+  "layer_head_dim": 128,
+  "layer_mlp_ratio": 1.5,
+  "layer_dropout": 0.1,
+  "layer_num_layers": 4,
+  "disable_tabular_attention_mask": true,
+  "cat_homoscedastic": false,
+  "num_homoscedastic": false
+}
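The `layer_*` keys imply 8 query heads of dimension 128 attending over only 2 key/value heads, i.e. grouped-query attention with a 4:1 sharing ratio. A shape-level sketch (illustrative only; the real layer is in `modelling/layer.py`):

```python
import torch
import torch.nn.functional as F

B, T = 2, 16                                     # batch size, token count
num_q_heads, num_kv_heads, head_dim = 8, 2, 128  # from config_model.json

q = torch.randn(B, num_q_heads, T, head_dim)
k = torch.randn(B, num_kv_heads, T, head_dim)
v = torch.randn(B, num_kv_heads, T, head_dim)

# Each key/value head serves num_q_heads // num_kv_heads = 4 query heads.
k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1)
v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1)

out = F.scaled_dot_product_attention(q, k, v)  # -> [2, 8, 16, 128]
```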
config/config_train.json
ADDED
@@ -0,0 +1,51 @@
+{
+  "paths": {
+    "config_data_path": "config/config_data.json",
+    "config_model_path": "config/config_model.json",
+    "output_dir": "runs/soilformer_hetero"
+  },
+  "seed": {
+    "seed": 42,
+    "deterministic": true
+  },
+  "runtime": {
+    "device": "cuda",
+    "num_epochs": 500,
+    "init_weight_std": 0.02
+  },
+  "optimization": {
+    "lr": 1e-4,
+    "beta1": 0.9,
+    "beta2": 0.999,
+    "eps": 1e-8,
+    "weight_decay": 0.02,
+    "max_grad_norm": 1.0,
+    "scheduler": {
+      "type": "cosine",
+      "total_epochs": 500,
+      "eta_min": 2e-5,
+      "warmup_epochs": 5,
+      "warmup_start_factor": 0.1
+    }
+  },
+  "loss": {
+    "cat_s_bound": 2,
+    "num_s_bound": 4
+  },
+  "checkpoint": {
+    "resume_checkpoint_path": null,
+    "epochs_per_save": 100,
+    "max_saved_checkpoints": 5
+  },
+  "logging": {
+    "tqdm": true,
+    "wandb": {
+      "enabled": true,
+      "project": "soilformer",
+      "entity": "kuangdai-leng",
+      "run_name": "train-hetero",
+      "mode": "online",
+      "dir": null
+    }
+  }
+}
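The scheduler block reads as 5 warmup epochs from a 0.1 start factor into a cosine decay toward `eta_min`. One way to realize that with stock PyTorch schedulers (a sketch; `modelling/train.py` may wire it differently):

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR

model = torch.nn.Linear(4, 4)  # placeholder module
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.02
)
warmup = LinearLR(optimizer, start_factor=0.1, total_iters=5)   # warmup_epochs
cosine = CosineAnnealingLR(optimizer, T_max=495, eta_min=2e-5)  # 500 - 5 epochs
scheduler = SequentialLR(optimizer, [warmup, cosine], milestones=[5])

for epoch in range(500):
    # ... one training epoch ...
    scheduler.step()
```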
data/cat_vocab.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:da160e500f0bf01207642f39b666d84d2787fae0f8ec21bb630e10e079780843
+size 14934
data/numeric_vocab.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ddcfc729da9e6f5830d58f6b53928a6fa6dcd108a0ddac3eb7fe67abed3dcadc
+size 17492
data/photo_map.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:018b7f5baaa58e8e3e5e2c6cf98d02aa547a13c6de55f1628984010fc331235c
+size 4651435
data/tabular_data.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4e87d387791bb95e3accab9afee8bda1e7e8722bad6e75d04c47a56787b24608
+size 103677102
data/tabular_meta.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:de393215e2bbe46b111cc5b604b7fd04c14d28a634a52b06dcb94fd7073200eb
+size 84654
data/tabular_meta_numeric_stats.csv
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:73ef56458b2d7d6cbb730b153a7fd9f445dba4a96d6f29483364e38a9102c150
+size 7714
example/input_card.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82977f50ba8a6a7d7542a4098434d232ad0feff5b1797b088a99b78504604420
+size 6114
example/input_card__masked.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:fadbf9afee9bd61fb87c4fd174306bcf6cae441e973b50afdbebd4bd433cb0be
+size 5902
example/input_card__unmasked.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:82977f50ba8a6a7d7542a4098434d232ad0feff5b1797b088a99b78504604420
+size 6114
example/output_card.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c11d1ae2c9ada543038ca38dcb5a2a496a8392ca52b07cea215cfd46f0172af0
+size 7261
example/output_card__acc.json
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d76e4284a3c9e5be054cceb8f96bb5c20434dc1ea11f1904a2d1663d910efd4e
+size 3388
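This accuracy card summarizes reconstruction quality over actively masked fields. A simplified sketch of the comparison the README describes (scalars only; the real `inference_predict_output_card.py` also handles vector-valued numeric features):

```python
from typing import Any, Dict


def masked_metrics(masked: Dict[str, Any], answer: Dict[str, Any],
                   output: Dict[str, Any]) -> Dict[str, float]:
    # Fields set to null (None) in the masked card were actively masked;
    # compare the model's predictions against the unmasked answer card.
    cat_keys = [k for k, v in masked["categorical"].items() if v is None]
    num_keys = [k for k, v in masked["numeric"].items() if v is None]
    acc = sum(output["categorical"][k] == answer["categorical"][k]
              for k in cat_keys) / max(len(cat_keys), 1)
    mae = sum(abs(float(output["numeric"][k]) - float(answer["numeric"][k]))
              for k in num_keys) / max(len(num_keys), 1)
    return {"categorical_accuracy": acc, "numeric_mae": mae}
```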
inference_create_input_card.py
ADDED
@@ -0,0 +1,318 @@
+import argparse
+import ast
+import json
+import random
+from pathlib import Path
+from typing import Any, Dict, Optional
+
+import numpy as np
+import pandas as pd
+
+from modelling.utils import load_json
+
+
+def to_jsonable(value: Any) -> Any:
+    if value is None:
+        return None
+    if isinstance(value, float) and pd.isna(value):
+        return None
+    if isinstance(value, np.generic):
+        return value.item()
+    return value
+
+
+def parse_optional_int(value: Optional[str]) -> Optional[int]:
+    if value is None:
+        return None
+    value = str(value).strip().lower()
+    if value in {"", "none", "null", "random"}:
+        return None
+    return int(value)
+
+
+def choose_row_index(num_rows: int, row_index: Optional[int], seed: int) -> int:
+    if num_rows <= 0:
+        raise RuntimeError("CSV has no rows")
+    if row_index is None:
+        return random.Random(seed).randrange(num_rows)
+    if row_index < 0 or row_index >= num_rows:
+        raise IndexError(f"row_index out of range: {row_index}; num_rows={num_rows}")
+    return row_index
+
+
+def validate_ratio(name: str, value: float) -> float:
+    value = float(value)
+    if not 0.0 <= value <= 1.0:
+        raise ValueError(f"{name} must be in [0, 1], got {value}")
+    return value
+
+
+def load_json_if_exists(path: Optional[str]) -> Optional[Dict[str, Any]]:
+    if not path:
+        return None
+    p = Path(path)
+    if not p.exists() or not p.is_file():
+        return None
+    return load_json(str(p))
+
+
+def get_categorical_columns(config_data: Dict[str, Any]) -> list[str]:
+    cat_vocab = load_json_if_exists(config_data.get("cat_vocab_path"))
+    if not isinstance(cat_vocab, dict):
+        return []
+    return list(cat_vocab.keys())
+
+
+def get_numeric_columns(config_data: Dict[str, Any]) -> list[str]:
+    numeric_vocab = load_json_if_exists(config_data.get("numeric_vocab_path"))
+    if not isinstance(numeric_vocab, dict):
+        return []
+
+    columns: list[str] = []
+    for group in numeric_vocab.get("groups", []):
+        for name in group.get("feature_names", []):
+            columns.append(str(name))
+    return columns
+
+
+def get_vision_input(config_data: Dict[str, Any], row: Dict[str, Any]) -> Dict[str, Any]:
+    photo_map = load_json_if_exists(config_data.get("photo_map_path"))
+    id_column = str(config_data.get("id_column", "id"))
+    sample_id = row.get(id_column)
+
+    if not isinstance(photo_map, dict) or sample_id is None:
+        return {"image_path_suffix": ""}
+
+    relative_path = photo_map.get(sample_id)
+    if relative_path is None:
+        relative_path = photo_map.get(str(sample_id))
+
+    if relative_path is None or relative_path == "":
+        return {"image_path_suffix": ""}
+
+    return {"image_path_suffix": str(relative_path)}
+
+
+def parse_numeric_value(value: Any) -> Any:
+    """
+    Convert known numeric CSV cells into readable JSON numbers.
+
+    Loader convention:
+    - missing numeric cell is ""
+    - scalar numeric cell is something like "12.3"
+    - vector numeric cell is something like "[1.2, 3.4]"
+    """
+    value = to_jsonable(value)
+
+    if value == "" or value is None:
+        return ""
+
+    if isinstance(value, (int, float)) and not isinstance(value, bool):
+        return value
+
+    if isinstance(value, str):
+        s = value.strip()
+        if s == "":
+            return ""
+
+        if s.startswith("[") and s.endswith("]"):
+            parsed = ast.literal_eval(s)
+            if not isinstance(parsed, (list, tuple)):
+                raise ValueError(f"Expected numeric vector list, got: {value!r}")
+            return [float(x) for x in parsed]
+
+        return float(s)
+
+    return value
+
+
+def create_unmasked_card(
+    row: Dict[str, Any],
+    cat_columns: list[str],
+    numeric_columns: list[str],
+    vision: Dict[str, Any],
+) -> Dict[str, Any]:
+    categorical = {col: row.get(col, "") for col in cat_columns if col in row}
+    numeric = {
+        col: parse_numeric_value(row.get(col, ""))
+        for col in numeric_columns
+        if col in row
+    }
+
+    return {
+        "categorical": categorical,
+        "numeric": numeric,
+        "vision": vision,
+    }
+
+
+def choose_mask_keys(values: Dict[str, Any], ratio: float, rng: random.Random) -> list[str]:
+    valid_keys = [k for k, v in values.items() if v not in ("", None)]
+    if ratio <= 0.0 or not valid_keys:
+        return []
+
+    k = int(round(len(valid_keys) * ratio))
+    k = max(0, min(k, len(valid_keys)))
+    if k == 0:
+        return []
+
+    return rng.sample(valid_keys, k)
+
+
+def create_masked_card(
+    unmasked_card: Dict[str, Any],
+    cat_mask_ratio: float,
+    num_mask_ratio: float,
+    seed: int,
+) -> Dict[str, Any]:
+    rng = random.Random(seed)
+    masked = json.loads(json.dumps(unmasked_card, ensure_ascii=False))
+
+    cat_keys = choose_mask_keys(masked["categorical"], cat_mask_ratio, rng)
+    num_keys = choose_mask_keys(masked["numeric"], num_mask_ratio, rng)
+
+    for key in cat_keys:
+        masked["categorical"][key] = None
+
+    for key in num_keys:
+        masked["numeric"][key] = None
+
+    return masked
+
+
+def output_paths_from_given_name(given_name: str) -> tuple[Path, Path]:
+    path = Path(given_name)
+    base = path.with_suffix("") if path.suffix == ".json" else path
+
+    unmasked_path = base.with_name(base.name + "__unmasked.json")
+    masked_path = base.with_name(base.name + "__masked.json")
+    return unmasked_path, masked_path
+
+
+def create_cards(
+    config_data_path: str,
+    row_index: Optional[int],
+    seed: int,
+    cat_mask_ratio: float,
+    num_mask_ratio: float,
+) -> tuple[Dict[str, Any], Dict[str, Any]]:
+    config_data = load_json(config_data_path)
+    csv_path = config_data["data_csv_path"]
+
+    # Match loader.py: empty cells remain "" instead of becoming NaN.
+    df = pd.read_csv(
+        csv_path,
+        keep_default_na=False,
+        na_filter=False,
+        low_memory=False,
+    )
+
+    chosen_row_index = choose_row_index(
+        num_rows=len(df),
+        row_index=row_index,
+        seed=seed,
+    )
+
+    row = {
+        str(k): to_jsonable(v)
+        for k, v in df.iloc[chosen_row_index].to_dict().items()
+    }
+
+    cat_columns = get_categorical_columns(config_data)
+    numeric_columns = get_numeric_columns(config_data)
+    vision = get_vision_input(config_data, row)
+
+    unmasked_card = create_unmasked_card(
+        row=row,
+        cat_columns=cat_columns,
+        numeric_columns=numeric_columns,
+        vision=vision,
+    )
+    masked_card = create_masked_card(
+        unmasked_card=unmasked_card,
+        cat_mask_ratio=cat_mask_ratio,
+        num_mask_ratio=num_mask_ratio,
+        seed=seed,
+    )
+
+    return unmasked_card, masked_card
+
+
+def save_json_pretty(obj: Dict[str, Any], path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", encoding="utf-8") as f:
+        json.dump(obj, f, ensure_ascii=False, indent=2)
+        f.write("\n")
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Create readable/editable SoilFormer input cards from one CSV row."
+    )
+    parser.add_argument(
+        "--config_data",
+        type=str,
+        default="config/config_data.json",
+        help="Path to config_data.json. Default: config/config_data.json",
+    )
+    parser.add_argument(
+        "--row_index",
+        type=str,
+        default=None,
+        help="CSV row index. Use None/null/random or omit for a random row.",
+    )
+    parser.add_argument(
+        "--output",
+        type=str,
+        required=True,
+        help="Given output name. Writes given_name__unmasked.json and given_name__masked.json.",
+    )
+    parser.add_argument(
+        "--cat_mask_ratio",
+        type=float,
+        default=0.15,
+        help="Ratio of non-missing categorical features to mask. Default: 0.15",
+    )
+    parser.add_argument(
+        "--num_mask_ratio",
+        type=float,
+        default=0.15,
+        help="Ratio of non-missing numeric features to mask. Default: 0.15",
+    )
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=0,
+        help="Seed for random row selection and feature masking. Default: 0",
+    )
+    args = parser.parse_args()
+
+    cat_mask_ratio = validate_ratio("cat_mask_ratio", args.cat_mask_ratio)
+    num_mask_ratio = validate_ratio("num_mask_ratio", args.num_mask_ratio)
+
+    unmasked_card, masked_card = create_cards(
+        config_data_path=args.config_data,
+        row_index=parse_optional_int(args.row_index),
+        seed=args.seed,
+        cat_mask_ratio=cat_mask_ratio,
+        num_mask_ratio=num_mask_ratio,
+    )
+
+    unmasked_path, masked_path = output_paths_from_given_name(args.output)
+    save_json_pretty(unmasked_card, unmasked_path)
+    save_json_pretty(masked_card, masked_path)
+
+    print(
+        json.dumps(
+            {
+                "status": "ok",
+                "unmasked_output": str(unmasked_path),
+                "masked_output": str(masked_path),
+            },
+            ensure_ascii=False,
+        )
+    )
+
+
+if __name__ == "__main__":
+    main()
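A quick interactive check of the naming convention `output_paths_from_given_name` implements above (run from the repository root so the import resolves):

```python
from inference_create_input_card import output_paths_from_given_name

# The given name "example/input_card.json" expands to the two sibling files
# the README documents:
unmasked, masked = output_paths_from_given_name("example/input_card.json")
print(unmasked)  # example/input_card__unmasked.json
print(masked)    # example/input_card__masked.json
```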
inference_predict_output_card.py
ADDED
|
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import ast
|
| 3 |
+
import json
|
| 4 |
+
import sys
|
| 5 |
+
from io import BytesIO
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any, Dict, Optional, Tuple
|
| 8 |
+
from urllib.parse import urljoin
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import requests
|
| 13 |
+
import torch
|
| 14 |
+
from PIL import Image
|
| 15 |
+
from torchvision import transforms
|
| 16 |
+
|
| 17 |
+
# The script is intended to live one level above ./modelling.
|
| 18 |
+
# modelling/ modules still contain some legacy absolute imports, so expose the
|
| 19 |
+
# modelling directory on sys.path as well.
|
| 20 |
+
PROJECT_ROOT = Path(__file__).resolve().parent
|
| 21 |
+
MODELLING_DIR = PROJECT_ROOT / "modelling"
|
| 22 |
+
if str(MODELLING_DIR) not in sys.path:
|
| 23 |
+
sys.path.insert(0, str(MODELLING_DIR))
|
| 24 |
+
|
| 25 |
+
from modelling.soilformer import SoilFormer # noqa: E402
|
| 26 |
+
from modelling.utils import get_dtype, load_json # noqa: E402
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# -----------------------------------------------------------------------------
|
| 30 |
+
# JSON helpers
|
| 31 |
+
# -----------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
def load_card(path: str) -> Dict[str, Any]:
|
| 34 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 35 |
+
obj = json.load(f)
|
| 36 |
+
if not isinstance(obj, dict):
|
| 37 |
+
raise ValueError(f"Card must be a JSON object: {path}")
|
| 38 |
+
return obj
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def save_json_pretty(obj: Dict[str, Any], path: Path) -> None:
|
| 42 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 43 |
+
with path.open("w", encoding="utf-8") as f:
|
| 44 |
+
json.dump(obj, f, ensure_ascii=False, indent=2)
|
| 45 |
+
f.write("\n")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def to_jsonable(x: Any) -> Any:
|
| 49 |
+
if isinstance(x, np.generic):
|
| 50 |
+
return x.item()
|
| 51 |
+
if isinstance(x, np.ndarray):
|
| 52 |
+
return x.tolist()
|
| 53 |
+
if isinstance(x, torch.Tensor):
|
| 54 |
+
x = x.detach().cpu()
|
| 55 |
+
if x.ndim == 0:
|
| 56 |
+
return x.item()
|
| 57 |
+
return x.tolist()
|
| 58 |
+
if isinstance(x, dict):
|
| 59 |
+
return {str(k): to_jsonable(v) for k, v in x.items()}
|
| 60 |
+
if isinstance(x, (list, tuple)):
|
| 61 |
+
return [to_jsonable(v) for v in x]
|
| 62 |
+
return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
# -----------------------------------------------------------------------------
|
| 66 |
+
# Runtime / model loading
|
| 67 |
+
# -----------------------------------------------------------------------------
|
| 68 |
+
|
| 69 |
+
def resolve_device(device_str: str) -> torch.device:
|
| 70 |
+
device_str = str(device_str).lower()
|
| 71 |
+
if device_str == "auto":
|
| 72 |
+
if torch.cuda.is_available():
|
| 73 |
+
return torch.device("cuda")
|
| 74 |
+
if torch.backends.mps.is_available():
|
| 75 |
+
return torch.device("mps")
|
| 76 |
+
return torch.device("cpu")
|
| 77 |
+
if device_str == "cuda":
|
| 78 |
+
if not torch.cuda.is_available():
|
| 79 |
+
raise RuntimeError("--device cuda requested, but CUDA is not available")
|
| 80 |
+
return torch.device("cuda")
|
| 81 |
+
if device_str == "mps":
|
| 82 |
+
if not torch.backends.mps.is_available():
|
| 83 |
+
raise RuntimeError("--device mps requested, but MPS is not available")
|
| 84 |
+
return torch.device("mps")
|
| 85 |
+
if device_str == "cpu":
|
| 86 |
+
return torch.device("cpu")
|
| 87 |
+
raise ValueError(f"Unsupported device: {device_str}")
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def load_model(args: argparse.Namespace, config_model: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> SoilFormer:
|
| 91 |
+
print("[INFO] Initializing model...")
|
| 92 |
+
model = SoilFormer(config=config_model, device=str(device))
|
| 93 |
+
|
| 94 |
+
print("[INFO] Loading checkpoint...")
|
| 95 |
+
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
| 96 |
+
missing, unexpected = model.load_state_dict(
|
| 97 |
+
checkpoint["model_state_dict"], strict=False
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
non_vision_missing = [k for k in missing if not k.startswith("vision_extractor.")]
|
| 101 |
+
if len(non_vision_missing) > 0:
|
| 102 |
+
raise RuntimeError(
|
| 103 |
+
f"[ERROR] Missing non-vision keys detected: {non_vision_missing[:10]}"
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
print(f"[INFO] Missing keys (vision only): {len(missing)}")
|
| 107 |
+
print(f"[INFO] Unexpected keys: {len(unexpected)}")
|
| 108 |
+
|
| 109 |
+
model.to(device=device, dtype=dtype)
|
| 110 |
+
model.eval()
|
| 111 |
+
return model
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# -----------------------------------------------------------------------------
|
| 115 |
+
# Metadata loading
|
| 116 |
+
# -----------------------------------------------------------------------------
|
| 117 |
+
|
| 118 |
+
def load_metadata(config_data: Dict[str, Any]) -> Dict[str, Any]:
|
| 119 |
+
cat_vocab = load_json(config_data["cat_vocab_path"])
|
| 120 |
+
numeric_vocab = load_json(config_data["numeric_vocab_path"])
|
| 121 |
+
|
| 122 |
+
stats_df = pd.read_csv(config_data["numeric_stats_path"])
|
| 123 |
+
numeric_stats = {}
|
| 124 |
+
for _, row in stats_df.iterrows():
|
| 125 |
+
col = row["column"]
|
| 126 |
+
mean = float(row["mean"])
|
| 127 |
+
std = float(row["std"])
|
| 128 |
+
if std == 0.0:
|
| 129 |
+
std = 1.0
|
| 130 |
+
numeric_stats[str(col)] = (mean, std)
|
| 131 |
+
|
| 132 |
+
cat_columns = list(cat_vocab.keys())
|
| 133 |
+
cat_mask_local_ids = [int(cat_vocab[col]["mask_local_id"]) for col in cat_columns]
|
| 134 |
+
|
| 135 |
+
id_to_label_by_col = {}
|
| 136 |
+
for col in cat_columns:
|
| 137 |
+
label2id = cat_vocab[col]["label2id"]
|
| 138 |
+
id_to_label_by_col[col] = {int(v): str(k) for k, v in label2id.items()}
|
| 139 |
+
|
| 140 |
+
return {
|
| 141 |
+
"cat_vocab": cat_vocab,
|
| 142 |
+
"numeric_vocab": numeric_vocab,
|
| 143 |
+
"numeric_stats": numeric_stats,
|
| 144 |
+
"cat_columns": cat_columns,
|
| 145 |
+
"cat_mask_local_ids": cat_mask_local_ids,
|
| 146 |
+
"id_to_label_by_col": id_to_label_by_col,
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
# -----------------------------------------------------------------------------
|
| 151 |
+
# Image handling, matching loader.py behavior
|
| 152 |
+
# -----------------------------------------------------------------------------
|
| 153 |
+
|
| 154 |
+
class CenterSquareCrop:
|
| 155 |
+
def __call__(self, img: Image.Image) -> Image.Image:
|
| 156 |
+
w, h = img.size
|
| 157 |
+
if w == h:
|
| 158 |
+
return img
|
| 159 |
+
if w > h:
|
| 160 |
+
left = (w - h) // 2
|
| 161 |
+
return img.crop((left, 0, left + h, h))
|
| 162 |
+
top = (h - w) // 2
|
| 163 |
+
return img.crop((0, top, w, top + w))
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def build_image_transform(image_size: int):
|
| 167 |
+
return transforms.Compose([
|
| 168 |
+
CenterSquareCrop(),
|
| 169 |
+
transforms.Resize((image_size, image_size)),
|
| 170 |
+
transforms.ToTensor(),
|
| 171 |
+
])
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def join_photo_root(photo_root: str, relative_path: str) -> str:
|
| 175 |
+
if photo_root.startswith("http://") or photo_root.startswith("https://"):
|
| 176 |
+
return urljoin(photo_root.rstrip("/") + "/", relative_path)
|
| 177 |
+
return photo_root.rstrip("/") + "/" + relative_path.lstrip("/")
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def load_image_tensor(image_path: str, image_size: int) -> torch.Tensor:
|
| 181 |
+
if image_path.startswith("http://") or image_path.startswith("https://"):
|
| 182 |
+
resp = requests.get(image_path, timeout=(3, 10))
|
| 183 |
+
resp.raise_for_status()
|
| 184 |
+
img = Image.open(BytesIO(resp.content)).convert("RGB")
|
| 185 |
+
else:
|
| 186 |
+
img = Image.open(image_path).convert("RGB")
|
| 187 |
+
return build_image_transform(image_size)(img)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# -----------------------------------------------------------------------------
|
| 191 |
+
# Tensorization from readable input card
|
| 192 |
+
# -----------------------------------------------------------------------------
|
| 193 |
+
|
| 194 |
+
def is_masked_or_missing(value: Any) -> bool:
|
| 195 |
+
return value is None or value == ""
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def parse_numeric_card_value(value: Any, n_in: int) -> Tuple[list[float], bool]:
|
| 199 |
+
if value is None or value == "":
|
| 200 |
+
return [0.0] * n_in, False
|
| 201 |
+
|
| 202 |
+
if n_in == 1:
|
| 203 |
+
if isinstance(value, list):
|
| 204 |
+
if len(value) != 1:
|
| 205 |
+
raise ValueError(f"Expected scalar or length-1 list for n_in=1, got {value!r}")
|
| 206 |
+
return [float(value[0])], True
|
| 207 |
+
return [float(value)], True
|
| 208 |
+
|
| 209 |
+
if isinstance(value, str):
|
| 210 |
+
parsed = ast.literal_eval(value)
|
| 211 |
+
else:
|
| 212 |
+
parsed = value
|
| 213 |
+
|
| 214 |
+
if not isinstance(parsed, (list, tuple)):
|
| 215 |
+
raise ValueError(f"Expected list-like numeric vector for n_in={n_in}, got {value!r}")
|
| 216 |
+
if len(parsed) != n_in:
|
| 217 |
+
raise ValueError(f"Numeric vector length mismatch: expected {n_in}, got {len(parsed)}")
|
| 218 |
+
return [float(v) for v in parsed], True
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def tensorize_card(
|
| 222 |
+
input_card: Dict[str, Any],
|
| 223 |
+
config_data: Dict[str, Any],
|
| 224 |
+
meta: Dict[str, Any],
|
| 225 |
+
) -> Dict[str, Any]:
|
| 226 |
+
categorical = input_card.get("categorical", {})
|
    numeric = input_card.get("numeric", {})
    vision = input_card.get("vision", {})

    if not isinstance(categorical, dict):
        raise ValueError("input_card['categorical'] must be an object")
    if not isinstance(numeric, dict):
        raise ValueError("input_card['numeric'] must be an object")
    if not isinstance(vision, dict):
        vision = {}

    # Categorical: raw label -> local id, null/"" -> mask id and invalid.
    cat_ids = []
    cat_valids = []
    for col, mask_id in zip(meta["cat_columns"], meta["cat_mask_local_ids"]):
        value = categorical.get(col, "")
        if is_masked_or_missing(value):
            cat_ids.append(mask_id)
            cat_valids.append(False)
        else:
            label2id = meta["cat_vocab"][col]["label2id"]
            if value not in label2id:
                raise KeyError(f"Unknown categorical value: column={col}, value={value!r}")
            cat_ids.append(int(label2id[value]))
            cat_valids.append(True)

    cat_local_ids = torch.tensor([cat_ids], dtype=torch.long)
    cat_valid_positions = torch.tensor([cat_valids], dtype=torch.bool)

    # Numeric: raw actual units -> z-score grouped tensors.
    numeric_values_by_nin = {}
    numeric_valid_positions_by_nin = {}

    for group in meta["numeric_vocab"]["groups"]:
        n_in = int(group["n_in"])
        values = []
        valids = []
        for feat in group["feature_names"]:
            feat = str(feat)
            raw_value = numeric.get(feat, "")
            parsed, is_valid = parse_numeric_card_value(raw_value, n_in)
            if is_valid:
                mean, std = meta["numeric_stats"][feat]
                parsed = [(v - mean) / std for v in parsed]
            values.append(parsed)
            valids.append(is_valid)

        numeric_values_by_nin[n_in] = torch.tensor([values], dtype=torch.float32)
        numeric_valid_positions_by_nin[n_in] = torch.tensor([valids], dtype=torch.bool)

    # Vision: readable card stores suffix only. Load/transform here.
    image_size = int(config_data["image_size"])
    image_path_suffix = vision.get("image_path_suffix", "")
    if image_path_suffix is None or image_path_suffix == "":
        pixel_values = torch.zeros(1, 3, image_size, image_size, dtype=torch.float32)
        vision_valid_positions = torch.tensor([False], dtype=torch.bool)
    else:
        image_path = join_photo_root(str(config_data["photo_root"]), str(image_path_suffix))
        try:
            image = load_image_tensor(image_path, image_size=image_size)
            pixel_values = image.unsqueeze(0)
            vision_valid_positions = torch.tensor([True], dtype=torch.bool)
        except Exception as exc:
            print(f"[WARN] Could not load image; using zero vision input: {exc}")
            pixel_values = torch.zeros(1, 3, image_size, image_size, dtype=torch.float32)
            vision_valid_positions = torch.tensor([False], dtype=torch.bool)

    return {
        "cat_local_ids": cat_local_ids,
        "cat_valid_positions": cat_valid_positions,
        "numeric_values_by_nin": numeric_values_by_nin,
        "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
        "pixel_values": pixel_values,
        "vision_valid_positions": vision_valid_positions,
    }

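
# Illustrative sketch only (not part of the original pipeline): how the card
# convention above maps to tensors. The column names and vocab below are made
# up; the real ones come from data/cat_vocab.json via load_metadata().
def _demo_card_convention() -> None:
    label2id = {"clay": 0, "loam": 1, "sand": 2, "__MASK__": 3}
    card = {"texture": "loam", "color": None, "structure": ""}  # null = predict, "" = missing
    ids, valids = [], []
    for col in ("texture", "color", "structure"):
        value = card[col]
        if value is None or value == "":
            ids.append(label2id["__MASK__"])  # both map to the mask id ...
            valids.append(False)              # ... and are flagged invalid
        else:
            ids.append(label2id[value])
            valids.append(True)
    print(torch.tensor([ids]))     # tensor([[1, 3, 3]])
    print(torch.tensor([valids]))  # tensor([[ True, False, False]])
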
def move_batch_to_device(batch: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> Dict[str, Any]:
    out = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            if value.dtype.is_floating_point:
                out[key] = value.to(device=device, dtype=dtype)
            else:
                out[key] = value.to(device=device)
        elif isinstance(value, dict):
            sub = {}
            for k, v in value.items():
                if isinstance(v, torch.Tensor):
                    if v.dtype.is_floating_point:
                        sub[k] = v.to(device=device, dtype=dtype)
                    else:
                        sub[k] = v.to(device=device)
                else:
                    sub[k] = v
            out[key] = sub
        else:
            out[key] = value
    return out


# -----------------------------------------------------------------------------
# Decoding model outputs to readable card
# -----------------------------------------------------------------------------

def denormalize_numeric(values_z: list[float], mean: float, std: float) -> list[float]:
    return [float(v) * float(std) + float(mean) for v in values_z]

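# e.g., with hypothetical stats mean=6.5, std=0.8:
#   denormalize_numeric([0.0, 1.0], 6.5, 0.8) -> [6.5, 7.3]
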
def decode_outputs(
    cat_logits_padded: torch.Tensor,
    valid_class_mask: torch.Tensor,
    value_by_nin: Dict[int, torch.Tensor],
    meta: Dict[str, Any],
) -> Dict[str, Any]:
    cat_logits = cat_logits_padded.detach().float().cpu()
    valid_class_mask = valid_class_mask.detach().cpu().bool()

    categorical_out = {}
    for m, col in enumerate(meta["cat_columns"]):
        cm = int(valid_class_mask[m].sum().item())
        logits = cat_logits[0, m, :cm]
        probs = torch.softmax(logits, dim=-1)
        pred_id = int(torch.argmax(probs).item())
        pred_label = meta["id_to_label_by_col"][col].get(pred_id, str(pred_id))
        categorical_out[col] = pred_label

    numeric_out = {}
    for group in meta["numeric_vocab"]["groups"]:
        n_in = int(group["n_in"])
        preds_z = value_by_nin[n_in].detach().float().cpu()[0]  # [V, n_in]
        for v_idx, feat in enumerate(group["feature_names"]):
            feat = str(feat)
            mean, std = meta["numeric_stats"][feat]
            raw_pred_values = denormalize_numeric(preds_z[v_idx].tolist(), mean, std)
            if n_in == 1:
                numeric_out[feat] = raw_pred_values[0]
            else:
                numeric_out[feat] = raw_pred_values

    return {
        "categorical": categorical_out,
        "numeric": numeric_out,
    }

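# Note on the slicing above: logits_padded fills invalid class positions with
# the dtype minimum (see modelling/decode_categorical.py), so restricting the
# softmax/argmax to the first cm entries and letting the padded tail carry
# ~zero probability are equivalent. A tiny self-contained check:
#
#   logits = torch.tensor([1.0, 2.0, torch.finfo(torch.float32).min])  # C_m=2, Cmax=3
#   torch.argmax(torch.softmax(logits[:2], -1))  # tensor(1)
#   torch.argmax(torch.softmax(logits, -1))      # tensor(1) as well
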
# -----------------------------------------------------------------------------
# Accuracy / MAE analysis
# -----------------------------------------------------------------------------

def masked_feature_names(input_card: Dict[str, Any], section: str) -> list[str]:
    values = input_card.get(section, {})
    if not isinstance(values, dict):
        return []
    return [k for k, v in values.items() if v is None]

def numeric_abs_errors(pred_value: Any, answer_value: Any) -> list[float]:
    if answer_value is None or answer_value == "":
        return []
    if pred_value is None or pred_value == "":
        return []

    if isinstance(answer_value, str):
        s = answer_value.strip()
        if s == "":
            return []
        if s.startswith("[") and s.endswith("]"):
            answer_value = [float(x) for x in ast.literal_eval(s)]
        else:
            answer_value = float(s)

    if isinstance(pred_value, str):
        s = pred_value.strip()
        if s.startswith("[") and s.endswith("]"):
            pred_value = [float(x) for x in ast.literal_eval(s)]
        else:
            pred_value = float(s)

    if isinstance(answer_value, (list, tuple)):
        if not isinstance(pred_value, (list, tuple)):
            return []
        if len(pred_value) != len(answer_value):
            return []
        return [abs(float(p) - float(a)) for p, a in zip(pred_value, answer_value)]

    return [abs(float(pred_value) - float(answer_value))]

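# Worked examples of the parsing rules above (values are hypothetical):
#   numeric_abs_errors("3.5", 3.0)               -> [0.5]
#   numeric_abs_errors([1.0, 2.0], "[1.5, 1.0]") -> [0.5, 1.0]
#   numeric_abs_errors(None, 3.0)                -> []  (no prediction, skipped)
#   numeric_abs_errors([1.0], [1.0, 2.0])        -> []  (length mismatch, skipped)
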
def evaluate_against_answer(
    input_card: Dict[str, Any],
    output_card: Dict[str, Any],
    answer_card: Dict[str, Any],
) -> Dict[str, Any]:
    cat_masked = masked_feature_names(input_card, "categorical")
    num_masked = masked_feature_names(input_card, "numeric")

    cat_details = {}
    correct = 0
    total = 0
    for feat in cat_masked:
        answer = answer_card.get("categorical", {}).get(feat)
        pred = output_card.get("categorical", {}).get(feat)
        if answer is None or answer == "":
            continue
        is_correct = pred == answer
        cat_details[feat] = {
            "predicted": pred,
            "answer": answer,
            "correct": bool(is_correct),
        }
        correct += int(is_correct)
        total += 1

    num_details = {}
    abs_errors_all = []
    for feat in num_masked:
        answer = answer_card.get("numeric", {}).get(feat)
        pred = output_card.get("numeric", {}).get(feat)
        errors = numeric_abs_errors(pred, answer)
        if not errors:
            continue
        mae = sum(errors) / len(errors)
        num_details[feat] = {
            "predicted": pred,
            "answer": answer,
            "absolute_error": errors[0] if len(errors) == 1 else errors,
            "mae": mae,
        }
        abs_errors_all.extend(errors)

    return {
        "categorical": {
            "accuracy": None if total == 0 else correct / total,
            "correct": correct,
            "total": total,
            "details": cat_details,
        },
        "numeric": {
            "mae": None if len(abs_errors_all) == 0 else sum(abs_errors_all) / len(abs_errors_all),
            "count": len(abs_errors_all),
            "details": num_details,
        },
        "note": "Metrics are computed only on fields that are null in input_card. Natural missing values \"\" are ignored.",
    }

def acc_path_from_output(output: str) -> Path:
    path = Path(output)
    if path.suffix == ".json":
        base = path.with_suffix("")
    else:
        base = path
    return base.with_name(base.name + "__acc.json")

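# e.g., acc_path_from_output("example/output_card.json")
#       -> Path("example/output_card__acc.json")  (matches the shipped example files)
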
# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------

def main() -> None:
    parser = argparse.ArgumentParser(description="Run SoilFormer inference from a readable input card.")
    parser.add_argument("--input_card", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--answer_card", type=str, default=None)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--config_data", type=str, default="config/config_data.json")
    parser.add_argument("--config_model", type=str, default="config/config_model.json")
    parser.add_argument("--device", type=str, default="auto", choices=["auto", "cuda", "mps", "cpu"])
    args = parser.parse_args()

    config_data = load_json(args.config_data)
    config_model = load_json(args.config_model)
    dtype = get_dtype(config_model.get("dtype", "bfloat16"))
    device = resolve_device(args.device)

    meta = load_metadata(config_data)
    input_card = load_card(args.input_card)
    batch = tensorize_card(input_card=input_card, config_data=config_data, meta=meta)
    batch = move_batch_to_device(batch, device=device, dtype=dtype)

    model = load_model(args=args, config_model=config_model, device=device, dtype=dtype)

    with torch.no_grad():
        cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, _ = model(
            cat_local_ids=batch["cat_local_ids"],
            numeric_values_by_nin=batch["numeric_values_by_nin"],
            cat_valid_positions=batch["cat_valid_positions"],
            numeric_valid_positions_by_nin=batch["numeric_valid_positions_by_nin"],
            pixel_values=batch["pixel_values"],
            vision_valid_positions=batch["vision_valid_positions"],
        )

    output_card = decode_outputs(
        cat_logits_padded=cat_logits_padded,
        valid_class_mask=valid_class_mask,
        value_by_nin=value_by_nin,
        meta=meta,
    )

    save_json_pretty(to_jsonable(output_card), Path(args.output))

    result = {"status": "ok", "output": args.output}

    if args.answer_card:
        answer_card = load_card(args.answer_card)
        acc_card = evaluate_against_answer(
            input_card=input_card,
            output_card=output_card,
            answer_card=answer_card,
        )
        acc_path = acc_path_from_output(args.output)
        save_json_pretty(to_jsonable(acc_card), acc_path)
        result["acc_output"] = str(acc_path)

    print(json.dumps(result, ensure_ascii=False))


if __name__ == "__main__":
    main()
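
A plausible end-to-end invocation against the shipped example files (the flags are exactly those registered in main() above; pairing --answer_card with example/input_card__unmasked.json is an assumption about how the masked/unmasked pair is meant to be used):

    python inference_predict_output_card.py \
        --input_card example/input_card__masked.json \
        --output example/output_card.json \
        --answer_card example/input_card__unmasked.json \
        --checkpoint model_weights/soilformer_pretrain/hetero_epoch_200.pt
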
model_weights/gemma3n_E2B_vision_only/config.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df49c2835315d4de6753bea989198e66157d84aa831738227f3bc705eab2d746
size 4455

model_weights/gemma3n_E2B_vision_only/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:eed8742f2e68b0d28bac29ee591a97e6738b6d040e0a5b69d270fca1d1453e20
size 597245920

model_weights/gemma3n_E2B_vision_only/modeling_gemma3n.py
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78b0b5d14177913d7956279f7a08b62f45f5b0ca6ab1993507fc653ad9579b0c
size 114392

model_weights/gemma3n_E2B_vision_only/processor_config.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a3f52ae9fb2eeed632fc99f14fa8b4405b17cd4b760a369cddf366f9ccf6855b
size 2262

model_weights/gemma3n_E2B_vision_only/tokenizer.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7fad9b5f6f930b43d292eb3c56c176a69292850ddd0abc02d9ea1dac3292c87a
size 33442428

model_weights/gemma3n_E2B_vision_only/tokenizer_config.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10c688d1767007b8614f275427198205507d941aefa6ae63c3e429ef87de7999
size 936

model_weights/gemma3n_E2B_vision_only/vision_extractor_config.json
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ea31eaf2aec2df075d62a4bca2209763e97a0141122257b07e62fe79e3cf4564
size 156

model_weights/soilformer_pretrain/hetero_epoch_200.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:057cd623e72bbf477bd46f346506acfd5741c2b57326d2bc73e723ac3ea949fc
size 276126967
modelling/__init__.py
ADDED
File without changes
modelling/decode_categorical.py
ADDED
@@ -0,0 +1,423 @@
# decode_categorical.py
# -*- coding: utf-8 -*-

"""
Categorical decoder for tabular transformer.

Design (column-wise heads):
- Each categorical column corresponds to exactly 1 token.
- Each column has its own classifier head:
      hidden_size -> num_classes[col]
  Optionally with a small MLP:
      hidden_size -> middle_size -> num_classes[col]

No loss is included here (caller will apply CrossEntropyLoss).
"""

from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn

from utils import load_json, GroupedMLP


# ============================================================
# Small head builder
# ============================================================

def _make_head(
    hidden_size: int,
    num_classes: int,
    middle_size: Optional[int],
    bias: bool = True,
) -> nn.Module:
    """
    Build a lightweight per-column classifier head.
    """
    if middle_size is None:
        return nn.Linear(hidden_size, num_classes, bias=bias)

    return nn.Sequential(
        nn.Linear(hidden_size, middle_size, bias=bias),
        nn.GELU(),
        nn.Linear(middle_size, num_classes, bias=bias),
    )


# ============================================================
# Decoder
# ============================================================

class CategoricalDecoder(nn.Module):
    """
    Column-wise categorical decoder.

    Design:
    - Each categorical column corresponds to exactly one token.
    - Each column has its own classifier head:
          hidden_size -> num_classes[col]
      Optionally with a small MLP:
          hidden_size -> middle_size -> num_classes[col]

    - In addition, the decoder predicts a per-sample, per-column
      log-variance term `s` used for heteroscedastic loss weighting.

    Input:
        x_cat_tokens: [B, M, H]
            B = batch size
            M = number of categorical columns (ordered by col_id)
            H = hidden size

    Outputs:

        Case 1 (return_padded=False):
            logits_list: List[Tensor] length M
                logits_list[m]: [B, num_classes[m]]

            s: [B, M]
                Predicted log-variance per sample and column:
                    s[b, m] = log sigma^2_{b,m}
                Intended for heteroscedastic loss weighting.

        Case 2 (return_padded=True):
            logits_padded: [B, M, Cmax]
                Logits padded to the maximum class count across columns.

            s: [B, M]
                Same uncertainty prediction as above.

            valid_mask: [M, Cmax]
                True for valid class indices for each column.
    """

    def __init__(
        self,
        hidden_size: int,
        cat_vocab_json: str,
        middle_size: Optional[int] = None,
        bias: bool = True,
        homoscedastic: bool = True,
    ):
        super().__init__()

        spec = load_json(cat_vocab_json)
        items = sorted(spec.items(), key=lambda x: x[1]["col_id"])

        col_ids: List[int] = []
        num_classes: List[int] = []

        for _, val in items:
            col_ids.append(int(val["col_id"]))
            num_classes.append(int(val["num_classes"]))

        self.hidden_size = int(hidden_size)
        self.num_cols = len(num_classes)
        self.middle_size = middle_size
        self.homoscedastic = bool(homoscedastic)

        # Buffers for debugging / validation / optional padded output
        self.register_buffer("cat_col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True)  # [M]
        self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True)  # [M]

        # Build per-column heads
        heads = []
        for c in num_classes:
            head = _make_head(self.hidden_size, c, middle_size, bias=bias)
            heads.append(head)

        self.heads = nn.ModuleList(heads)

        if self.homoscedastic:
            self.s_param = nn.Parameter(torch.zeros(self.num_cols))
            self.s_head = None
        else:
            self.s_head = GroupedMLP(
                n_var=self.num_cols,
                n_in=self.hidden_size,
                n_out=1,
                middle_size=self.middle_size,
            )
            self.s_param = None

    def init_weights(self, std: float = 0.02):
        for head in self.heads:
            for module in head.modules():
                if isinstance(module, nn.Linear):
                    nn.init.normal_(module.weight, std=std)
                    if module.bias is not None:
                        nn.init.zeros_(module.bias)

        if self.homoscedastic:
            nn.init.zeros_(self.s_param)
        else:
            self.s_head.init_weights(std=0.0)

    def _check_input(self, x_cat_tokens: torch.Tensor) -> Tuple[int, int, int]:
        if x_cat_tokens.dim() != 3:
            raise ValueError(f"x_cat_tokens must be [B,M,H], got {tuple(x_cat_tokens.shape)}")
        B, M, H = x_cat_tokens.shape
        if H != self.hidden_size:
            raise ValueError(f"hidden_size mismatch: got {H}, expected {self.hidden_size}")
        if M != self.num_cols:
            raise ValueError(f"categorical token count mismatch: got M={M}, expected {self.num_cols}")
        return B, M, H

    @torch.no_grad()
    def _build_valid_mask(self, device: torch.device) -> torch.Tensor:
        """
        valid_mask[m, j] = True iff j < num_classes[m]
        """
        M = self.num_cols
        cmax = int(self.num_classes.max().item())
        ar = torch.arange(cmax, device=device).view(1, cmax).expand(M, cmax)
        nc = self.num_classes.view(M, 1).expand(M, cmax)
        return ar < nc

    def forward(
        self,
        x_cat_tokens: torch.Tensor,
        return_padded: bool = False,
        pad_value: Optional[float] = None,
    ) -> Union[
        Tuple[List[torch.Tensor], torch.Tensor],
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ]:
        """
        Args:
            x_cat_tokens: [B, M, H]
                B = batch size
                M = number of categorical columns
                H = hidden size (per-column token embedding dim)

            return_padded:
                False:
                    return (logits_list, s)
                True:
                    return (logits_padded, s, valid_mask)

            pad_value:
                Value used to fill invalid class positions in padded logits.

        Returns:

            Case 1 (return_padded=False):
                logits_list: List length M
                    logits_list[m]: [B, C_m]
                s: [B, M]
                    s[b, m] = log sigma^2 for sample b, column m

            Case 2 (return_padded=True):
                logits_padded: [B, M, Cmax]
                s: [B, M]
                valid_mask: [M, Cmax]
        """

        # --------------------------------------------------------
        # 1) Basic shape validation
        # --------------------------------------------------------
        # Ensures x_cat_tokens is [B,M,H] and matches decoder config
        B, M, _ = self._check_input(x_cat_tokens)

        # --------------------------------------------------------
        # 2) Per-column categorical logits
        # --------------------------------------------------------
        # We still use per-column heads because each column
        # can have a different number of classes C_m.
        #
        # logits_list[m] shape: [B, C_m]
        logits_list: List[torch.Tensor] = []
        for m in range(M):
            # x_cat_tokens[:, m, :] -> [B,H]
            # heads[m] maps H -> C_m
            logits_m = self.heads[m](x_cat_tokens[:, m, :])
            logits_list.append(logits_m)

        # --------------------------------------------------------
        # 3) Sample-wise & column-wise uncertainty (log-variance)
        # --------------------------------------------------------
        # s_head processes all columns at once (grouped, no loop)
        #
        # Input:  [B,M,H]
        # Output: [B,M]
        #
        # s[b,m] = log(sigma_{b,m}^2)
        if self.homoscedastic:
            s = self.s_param.unsqueeze(0).expand(B, -1)
        else:
            s = self.s_head(x_cat_tokens).squeeze(-1)

        # --------------------------------------------------------
        # 4) If no padded output requested
        # --------------------------------------------------------
        if not return_padded:
            # Return:
            #   logits_list: List of length M
            #   s: [B,M]
            return logits_list, s

        # --------------------------------------------------------
        # 5) Build padded logits tensor
        # --------------------------------------------------------
        # We unify different C_m into a common Cmax.
        #
        # logits_padded shape: [B,M,Cmax]
        cmax = int(self.num_classes.max().item())

        if pad_value is None:
            pad_value = torch.finfo(x_cat_tokens.dtype).min
        logits_padded = torch.full(
            (B, M, cmax),
            pad_value,
            device=x_cat_tokens.device,
            dtype=x_cat_tokens.dtype,
        )

        # Fill valid class positions per column
        for m in range(M):
            cm = logits_list[m].size(-1)  # C_m
            logits_padded[:, m, :cm] = logits_list[m]

        # --------------------------------------------------------
        # 6) Build validity mask
        # --------------------------------------------------------
        # valid_mask[m,j] = True  if j < C_m
        #                 = False otherwise
        #
        # Shape: [M, Cmax]
        valid_class_mask = self._build_valid_mask(device=x_cat_tokens.device)

        # --------------------------------------------------------
        # 7) Return padded outputs
        # --------------------------------------------------------
        return logits_padded, s, valid_class_mask


# ============================================================
# DEMO
# ============================================================

def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--middle_size", type=int, default=None)
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    dtype = dtype_map[args.dtype]

    # --------------------------------------------------------
    # Load vocab spec
    # --------------------------------------------------------
    spec = load_json(args.cat_vocab_json)
    items = sorted(spec.items(), key=lambda x_: x_[1]["col_id"])

    M = len(items)
    B = args.batch_size
    H = args.hidden_size

    num_classes = [int(s["num_classes"]) for _, s in items]

    print("===== Categorical Columns =====")
    for i, (name, s) in enumerate(items):
        print(f"{i:03d} {name:20s} classes={s['num_classes']}")
    print()

    # --------------------------------------------------------
    # Build model
    # --------------------------------------------------------
    model = CategoricalDecoder(
        hidden_size=args.hidden_size,
        cat_vocab_json=args.cat_vocab_json,
        middle_size=args.middle_size,
    ).to(device=device, dtype=dtype)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f"Model parameters: {total_params:,} (trainable: {trainable_params:,})")
    print()

    # --------------------------------------------------------
    # Fake input tokens
    # --------------------------------------------------------
    x = torch.randn(B, M, H, device=device, dtype=dtype)

    print("Input tokens shape:", tuple(x.shape))
    print()

    # --------------------------------------------------------
    # Case 1: logits_list
    # --------------------------------------------------------
    print("===== Forward: logits_list mode =====")

    with torch.no_grad():
        logits_list, s = model(x, return_padded=False)

    for m, (name, spec_item) in enumerate(items):
        C = spec_item["num_classes"]
        print(f"{m:03d} {name:20s} logits:", tuple(logits_list[m].shape), f"(expected {(B, C)})")

    print("s shape:", tuple(s.shape))
    print()

    # --------------------------------------------------------
    # Case 2: padded logits
    # --------------------------------------------------------
    print("===== Forward: padded mode =====")

    with torch.no_grad():
        logits_padded, s2, valid_mask = model(x, return_padded=True)

    print("logits_padded:", tuple(logits_padded.shape))
    print("s:", tuple(s2.shape))
    print("valid_mask:", tuple(valid_mask.shape))
    print()

    # --------------------------------------------------------
    # Visualize valid mask
    # --------------------------------------------------------
    print("===== Valid class mask (first 10 columns) =====")

    cols_to_show = min(10, M)
    for m in range(cols_to_show):
        cm = num_classes[m]
        valid = valid_mask[m].sum().item()
        print(f"col {m:02d} num_classes={cm} valid_mask_sum={valid}")

    print()

    # --------------------------------------------------------
    # Check padded logits correctness
    # --------------------------------------------------------
    print("===== Padded logits sanity check =====")

    for m in range(cols_to_show):
        cm = num_classes[m]

        valid_region = logits_padded[:, m, :cm]
        padded_region = logits_padded[:, m, cm:]

        print(f"col {m:02d} valid region shape:", tuple(valid_region.shape))

        if padded_region.numel() > 0:
            print(f"col {m:02d} padded region mean:", padded_region.mean().item())

    print()

    print("Demo finished successfully.")


if __name__ == "__main__":
    _demo_main()
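
The `s = log sigma^2` outputs above are only predicted here; the weighting itself happens in modelling/train.py (not shown in this section), and the checkpoint name hetero_epoch_200.pt suggests the heteroscedastic variant. A minimal sketch of the usual form (after Kendall & Gal; the exact loss SoilFormer uses may differ):

import torch

# Hypothetical per-column cross-entropies and predicted log-variances for one sample.
ce = torch.tensor([1.2, 0.3])
s = torch.tensor([0.5, -0.2])

# Columns the model is uncertain about (large s) are down-weighted, with s as a
# regularizer that keeps the predicted variance from growing without bound.
loss = (torch.exp(-s) * ce + s).mean()
print(loss)  # tensor(0.6971)
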
modelling/decode_numeric.py
ADDED
@@ -0,0 +1,238 @@
# decode_numeric.py
# -*- coding: utf-8 -*-

"""
Numeric decoder module for tabular transformer.

Symmetric to embed_numeric.py (bucketed by n_in):
- For each bucket (same n_in), we decode tokens without a Python for-loop over columns.
- Uses a batched per-variable MLP with per-column parameters (NOT shared across V).

Input:
    x_tokens: [B, total_numeric_tokens, H]
        token order must match numeric_vocab.json:
        groups by n_in ascending, within group by feature name,
        and within each feature: n_in tokens.

Output:
    values_by_nin: Dict[int, Tensor]
        n_in -> x_hat [B, V, n_in]

middle_size:
    - None: 1-layer per-variable Linear
    - int : 2-layer per-variable MLP (Linear -> GELU -> Linear)
"""

from typing import Dict, List, Optional

import torch
import torch.nn as nn

from utils import GroupedMLP, load_json


class NumericDecoder(nn.Module):
    """
    Decode numeric tokens back to numeric values, bucketed by n_in.

    Input:
        x_tokens: [B, total_numeric_tokens, H]

    Output:
        values_by_nin:
            n_in -> y_hat [B, V, n_in]

        s_by_nin:
            n_in -> s [B, V]
            where s = log(sigma^2), shared across the n_in dimensions
            of each variable, intended for heteroscedastic loss computation.
    """

    def __init__(
        self,
        hidden_size: int,
        numeric_vocab_json: str,
        middle_size: Optional[int] = None,
        homoscedastic: bool = True,
    ):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.homoscedastic = bool(homoscedastic)

        spec = load_json(numeric_vocab_json)
        self.groups: List[Dict] = list(spec["groups"])
        self.total_numeric_tokens = int(spec["total_numeric_tokens"])
        self.group_token_offsets: Dict[str, int] = dict(spec.get("group_token_offsets", {}))

        self.group_v_decoders = nn.ModuleList()
        self.group_s_decoders = nn.ModuleList()
        self.group_nins: List[int] = []
        self.group_Vs: List[int] = []

        for g in self.groups:
            n_in = int(g["n_in"])
            names = list(g["feature_names"])
            V = len(names)

            self.group_nins.append(n_in)  # noqa
            self.group_Vs.append(V)

            # value decoder: [B,V,n_in*H] -> [B,V,n_in]
            self.group_v_decoders.append(
                GroupedMLP(
                    n_var=V,
                    n_in=n_in * self.hidden_size,
                    n_out=n_in,
                    middle_size=self.middle_size,
                )
            )

            # uncertainty decoder: [B,V,H] -> [B,V,1] -> [B,V]
            if not self.homoscedastic:
                self.group_s_decoders.append(
                    GroupedMLP(
                        n_var=V,
                        n_in=self.hidden_size,
                        n_out=1,
                        middle_size=self.middle_size,
                    )
                )

        if self.homoscedastic:
            self.group_s_params = nn.ParameterList(
                [nn.Parameter(torch.zeros(V)) for V in self.group_Vs]
            )
        else:
            self.group_s_params = None

        # spec integrity check
        running = 0
        for g in self.groups:
            n_in = int(g["n_in"])
            V = len(g["feature_names"])
            key = str(n_in)

            if key not in self.group_token_offsets:
                raise ValueError(f"Missing group_token_offsets entry for n_in={n_in}")
            if int(self.group_token_offsets[key]) != running:
                raise ValueError(
                    f"group_token_offsets[{key}]={self.group_token_offsets[key]} does not match expected {running}"
                )

            running += V * n_in

        if running != self.total_numeric_tokens:
            raise ValueError(
                f"total_numeric_tokens={self.total_numeric_tokens} does not match expected {running}"
            )

    def init_weights(self, std: float = 0.02):
        for dec in self.group_v_decoders:
            dec.init_weights(std=std)

        if self.homoscedastic:
            for p in self.group_s_params:
                nn.init.zeros_(p)
        else:
            for dec in self.group_s_decoders:
                dec.init_weights(std=0.0)

    def forward(self, x_tokens: torch.Tensor):
        if x_tokens.dim() != 3:
            raise ValueError(f"x_tokens must be [B,T,H], got {tuple(x_tokens.shape)}")

        B, T, H = x_tokens.shape
        if H != self.hidden_size:
            raise ValueError(f"hidden_size mismatch: got H={H}, expected {self.hidden_size}")
        if T != self.total_numeric_tokens:
            raise ValueError(f"token length mismatch: got T={T}, expected {self.total_numeric_tokens}")

        value_out: Dict[int, torch.Tensor] = {}
        s_out: Dict[int, torch.Tensor] = {}

        for gi, n_in in enumerate(self.group_nins):
            key = str(n_in)
            start = int(self.group_token_offsets[key])

            V = self.group_Vs[gi]
            length = V * n_in

            xg_tok = x_tokens[:, start:start + length, :]  # [B, V*n_in, H]
            xg_tok4 = xg_tok.reshape(B, V, n_in, H)        # [B, V, n_in, H]
            xg_flat = xg_tok4.reshape(B, V, n_in * H)      # [B, V, n_in*H]

            # values: [B, V, n_in]
            y = self.group_v_decoders[gi](xg_flat)

            # s = log sigma^2: [B, V]
            if self.homoscedastic:
                s = self.group_s_params[gi].unsqueeze(0).expand(B, -1)
            else:
                x_var = xg_tok4.mean(dim=2)  # [B, V, H]
                s = self.group_s_decoders[gi](x_var).squeeze(-1)  # [B, V]

            value_out[n_in] = y
            s_out[n_in] = s

        return value_out, s_out


# ============================================================
# DEMO
# ============================================================

def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
    parser.add_argument("--hidden_size", type=int, default=768)
    parser.add_argument("--middle_size", type=int, default=-1,
                        help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--device", type=str, default=None)
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    # Directly load existing numeric vocab spec
    spec = load_json(args.numeric_vocab_json)
    print(f"Loaded numeric vocab spec from: {args.numeric_vocab_json}")
    print("Groups (n_in -> V):", {int(g['n_in']): len(g['feature_names']) for g in spec["groups"]})
    print("total_numeric_tokens:", spec["total_numeric_tokens"])
    print("group_token_offsets:", spec["group_token_offsets"])

    middle_size = None if args.middle_size < 0 else int(args.middle_size)
    model = NumericDecoder(
        hidden_size=args.hidden_size,
        numeric_vocab_json=args.numeric_vocab_json,
        middle_size=middle_size,
    ).to(device=device, dtype=dtype)
    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters (NumericDecoder): {total_params:,} (trainable: {trainable_params:,})")

    B = args.batch_size
    T = int(spec["total_numeric_tokens"])
    H = args.hidden_size

    x_tokens = torch.randn(B, T, H, device=device, dtype=dtype)

    with torch.no_grad():
        values_by_nin, s_by_nin = model(x_tokens)

    print("Input tokens:", tuple(x_tokens.shape), x_tokens.dtype, x_tokens.device)
    print("Decoded values:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
    print("Decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})
    # values_by_nin[n_in]: [B, V, n_in]
    # s_by_nin[n_in]: [B, V]


if __name__ == "__main__":
    _demo_main()
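
The bucketed token layout that NumericDecoder slices is fixed by numeric_vocab.json. A small sketch of how the offsets compose, mirroring the integrity check in __init__ (feature names here are invented):

# Groups sorted by n_in; each feature owns n_in consecutive tokens.
groups = [
    {"n_in": 1, "feature_names": ["ph", "clay_pct"]},  # tokens 0..1
    {"n_in": 4, "feature_names": ["depth_profile"]},   # tokens 2..5
]

offset = 0
group_token_offsets = {}
for g in groups:
    group_token_offsets[str(g["n_in"])] = offset
    offset += len(g["feature_names"]) * g["n_in"]

print(group_token_offsets)                # {'1': 0, '4': 2}
print("total_numeric_tokens =", offset)   # total_numeric_tokens = 6
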
modelling/embed_categorical.py
ADDED
@@ -0,0 +1,322 @@
| 1 |
+
# embed_categorical.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Categorical embedding module for tabular transformer.
|
| 6 |
+
|
| 7 |
+
Design:
|
| 8 |
+
- Each categorical column = 1 token
|
| 9 |
+
- Value embedding: ONE global lookup table using (offset + local_id)
|
| 10 |
+
- ID embedding: ONE categorical column-ID embedding table
|
| 11 |
+
- Explicit col_id stored in cat_vocab.json (no implicit ordering assumptions)
|
| 12 |
+
|
| 13 |
+
Outputs:
|
| 14 |
+
local_ids [B,M] -> tokens [B,M,H]
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from dataclasses import dataclass
|
| 18 |
+
from typing import Dict, List, Optional, Tuple
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
from utils import load_json, save_json
|
| 24 |
+
|
| 25 |
+
SPECIAL_MASK = "__MASK__"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ============================================================
|
| 29 |
+
# Meta → categorical column list
|
| 30 |
+
# ============================================================
|
| 31 |
+
|
| 32 |
+
def get_categorical_feature_names_from_meta(tabular_meta: Dict) -> List[str]:
|
| 33 |
+
"""
|
| 34 |
+
Deterministic ordering:
|
| 35 |
+
alphabetical by feature name.
|
| 36 |
+
"""
|
| 37 |
+
cols = []
|
| 38 |
+
for k, v in tabular_meta.items():
|
| 39 |
+
if v.get("dataclass") == "categorical" and not v.get("is_array_valued", False):
|
| 40 |
+
cols.append(k)
|
| 41 |
+
return sorted(cols)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ============================================================
|
| 45 |
+
# Vocab spec
|
| 46 |
+
# ============================================================
|
| 47 |
+
|
| 48 |
+
@dataclass
|
| 49 |
+
class CatColSpec:
|
| 50 |
+
name: str
|
| 51 |
+
col_id: int
|
| 52 |
+
offset: int
|
| 53 |
+
num_classes: int
|
| 54 |
+
mask_local_id: int
|
| 55 |
+
label2id: Dict[str, int]
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def build_cat_vocab_spec_from_meta(
|
| 59 |
+
tabular_meta: Dict,
|
| 60 |
+
categorical_feature_names: List[str],
|
| 61 |
+
label_order: str = "alpha",
|
| 62 |
+
) -> Dict[str, CatColSpec]:
|
| 63 |
+
vocab: Dict[str, CatColSpec] = {}
|
| 64 |
+
|
| 65 |
+
offset = 0
|
| 66 |
+
for j, col in enumerate(categorical_feature_names):
|
| 67 |
+
info = tabular_meta[col]
|
| 68 |
+
class_stats = info.get("class_stats", {}) or {}
|
| 69 |
+
|
| 70 |
+
# deterministic label order
|
| 71 |
+
if label_order == "alpha":
|
| 72 |
+
labels = sorted(class_stats.keys())
|
| 73 |
+
elif label_order == "freq_desc":
|
| 74 |
+
labels = sorted(class_stats.keys(), key=lambda k: (-class_stats[k], k))
|
| 75 |
+
else:
|
| 76 |
+
raise ValueError("label_order must be alpha or freq_desc")
|
| 77 |
+
|
| 78 |
+
label2id = {lab: i for i, lab in enumerate(labels)}
|
| 79 |
+
|
| 80 |
+
mask_local_id = len(labels)
|
| 81 |
+
label2id[SPECIAL_MASK] = mask_local_id
|
| 82 |
+
|
| 83 |
+
spec = CatColSpec(
|
| 84 |
+
name=col,
|
| 85 |
+
col_id=j, # EXPLICIT categorical column id
|
| 86 |
+
offset=offset,
|
| 87 |
+
num_classes=mask_local_id + 1,
|
| 88 |
+
mask_local_id=mask_local_id,
|
| 89 |
+
label2id=label2id,
|
| 90 |
+
)
|
| 91 |
+
vocab[col] = spec
|
| 92 |
+
|
| 93 |
+
offset += spec.num_classes
|
| 94 |
+
|
| 95 |
+
return vocab
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def save_cat_vocab_json(vocab: Dict[str, CatColSpec], path: str) -> None:
|
| 99 |
+
out = {}
|
| 100 |
+
|
| 101 |
+
for col, spec in vocab.items():
|
| 102 |
+
out[col] = {
|
| 103 |
+
"col_id": spec.col_id,
|
| 104 |
+
"offset": spec.offset,
|
| 105 |
+
"num_classes": spec.num_classes,
|
| 106 |
+
"mask_local_id": spec.mask_local_id,
|
| 107 |
+
"global_id_start": spec.offset,
|
| 108 |
+
"global_id_end": spec.offset + spec.num_classes - 1,
|
| 109 |
+
"label2id": spec.label2id,
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
save_json(out, path)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ============================================================
|
| 116 |
+
# Embedding modules
|
| 117 |
+
# ============================================================
|
| 118 |
+
|
| 119 |
+
class CategoricalValueEmbedding(nn.Module):
|
| 120 |
+
"""
|
| 121 |
+
Global value embedding using offsets.
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(self, hidden_size: int, cat_vocab_json: str):
|
| 125 |
+
super().__init__()
|
| 126 |
+
|
| 127 |
+
spec = load_json(cat_vocab_json)
|
| 128 |
+
|
| 129 |
+
# sort by col_id to ensure consistent tensor layout
|
| 130 |
+
items = sorted(spec.items(), key=lambda x: x[1]["col_id"])
|
| 131 |
+
|
| 132 |
+
offsets = []
|
| 133 |
+
num_classes = []
|
| 134 |
+
col_ids = []
|
| 135 |
+
|
| 136 |
+
total_vocab = 0
|
| 137 |
+
|
| 138 |
+
for name, s in items:
|
| 139 |
+
offsets.append(int(s["offset"]))
|
| 140 |
+
num_classes.append(int(s["num_classes"]))
|
| 141 |
+
col_ids.append(int(s["col_id"]))
|
| 142 |
+
total_vocab = max(total_vocab, s["offset"] + s["num_classes"])
|
| 143 |
+
|
| 144 |
+
self.hidden_size = int(hidden_size)
|
| 145 |
+
self.total_vocab_size = int(total_vocab)
|
| 146 |
+
# Merge all classes to avoid many small nn.Embedding modules
|
| 147 |
+
self.emb = nn.Embedding(self.total_vocab_size, self.hidden_size)
|
| 148 |
+
|
| 149 |
+
self.register_buffer("offsets", torch.tensor(offsets, dtype=torch.long), persistent=True)
|
| 150 |
+
self.register_buffer("num_classes", torch.tensor(num_classes, dtype=torch.long), persistent=True)
|
| 151 |
+
self.register_buffer("col_ids", torch.tensor(col_ids, dtype=torch.long), persistent=True)
|
| 152 |
+
|
| 153 |
+
def init_weights(self, std=0.02):
|
| 154 |
+
nn.init.normal_(self.emb.weight, std=std)
|
| 155 |
+
|
| 156 |
+
def forward(self, local_ids: torch.LongTensor) -> torch.Tensor:
|
| 157 |
+
"""
|
| 158 |
+
local_ids: [B,M]
|
| 159 |
+
returns: [B,M,H]
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
if local_ids.dim() != 2:
|
| 163 |
+
raise ValueError("local_ids must be [B,M]")
|
| 164 |
+
|
| 165 |
+
B, M = local_ids.shape
|
| 166 |
+
|
| 167 |
+
if M != self.offsets.numel():
|
| 168 |
+
raise ValueError("Column count mismatch")
|
| 169 |
+
|
| 170 |
+
if torch.any(local_ids < 0):
|
| 171 |
+
raise ValueError("Negative local_id")
|
| 172 |
+
|
| 173 |
+
nc = self.num_classes.view(1, M).expand(B, M)
|
| 174 |
+
if torch.any(local_ids >= nc):
|
| 175 |
+
raise ValueError("local_ids out of range")
|
| 176 |
+
|
| 177 |
+
gid = self.offsets.view(1, M) + local_ids
|
| 178 |
+
return self.emb(gid)
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class CategoricalIdEmbedding(nn.Module):
|
| 182 |
+
"""
|
| 183 |
+
Explicit categorical column ID embedding.
|
| 184 |
+
"""
|
| 185 |
+
|
| 186 |
+
def __init__(self, hidden_size: int, cat_vocab_json: str):
|
| 187 |
+
super().__init__()
|
| 188 |
+
|
| 189 |
+
spec = load_json(cat_vocab_json)
|
| 190 |
+
items = sorted(spec.items(), key=lambda x: x[1]["col_id"])
|
| 191 |
+
|
| 192 |
+
col_ids = [s["col_id"] for _, s in items]
|
| 193 |
+
max_col_id = max(col_ids)
|
| 194 |
+
|
| 195 |
+
self.emb = nn.Embedding(max_col_id + 1, hidden_size)
|
| 196 |
+
|
| 197 |
+
self.register_buffer(
|
| 198 |
+
"cat_col_ids",
|
| 199 |
+
torch.tensor(col_ids, dtype=torch.long),
|
| 200 |
+
persistent=True,
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
self.hidden_size = hidden_size
|
| 204 |
+
|
| 205 |
+
def init_weights(self, std=0.02):
|
| 206 |
+
nn.init.normal_(self.emb.weight, std=std)
|
| 207 |
+
|
| 208 |
+
def forward(self, batch_size: int) -> torch.Tensor:
|
| 209 |
+
"""
|
| 210 |
+
returns [B,M,H]
|
| 211 |
+
"""
|
| 212 |
+
id_vec = self.emb(self.cat_col_ids) # [M,H]
|
| 213 |
+
return id_vec.view(1, -1, self.hidden_size).expand(batch_size, -1, -1)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
class CategoricalEmbedding(nn.Module):
|
| 217 |
+
"""
|
| 218 |
+
token = value_embedding + categorical_id_embedding
|
| 219 |
+
"""
|
| 220 |
+
|
| 221 |
+
def __init__(self, hidden_size: int, cat_vocab_json: str):
|
| 222 |
+
super().__init__()
|
| 223 |
+
|
| 224 |
+
self.value_emb = CategoricalValueEmbedding(hidden_size, cat_vocab_json)
|
| 225 |
+
self.id_emb = CategoricalIdEmbedding(hidden_size, cat_vocab_json)
|
| 226 |
+
|
| 227 |
+
def init_weights(self, std=0.02):
|
| 228 |
+
self.value_emb.init_weights(std=std)
|
| 229 |
+
self.id_emb.init_weights(std=std)
|
| 230 |
+
|
| 231 |
+
def forward(
|
| 232 |
+
self,
|
| 233 |
+
local_ids: torch.LongTensor, # [B, M]
|
| 234 |
+
valid_positions: Optional[torch.Tensor] = None, # Bool [B,M] (True=valid) or indices [K,2]
|
| 235 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 236 |
+
"""
|
| 237 |
+
Returns:
|
| 238 |
+
tokens: [B, M, H]
|
| 239 |
+
token_mask: [B, M] (1=valid, 0=invalid)
|
| 240 |
+
"""
|
| 241 |
+
if local_ids.dim() != 2:
|
| 242 |
+
raise ValueError(f"local_ids must be [B,M], got {tuple(local_ids.shape)}")
|
| 243 |
+
B, M = local_ids.shape
|
| 244 |
+
|
| 245 |
+
tokens = self.value_emb(local_ids) + self.id_emb(B) # [B,M,H]
|
| 246 |
+
|
| 247 |
+
# Default: all tokens are valid
|
| 248 |
+
valid = torch.ones((B, M), dtype=torch.bool, device=local_ids.device)
|
| 249 |
+
|
| 250 |
+
if valid_positions is not None:
|
| 251 |
+
if valid_positions.dtype == torch.bool:
|
| 252 |
+
if valid_positions.shape != (B, M):
|
| 253 |
+
raise ValueError(
|
| 254 |
+
f"valid_positions (bool) must be [B,M]=({B}, {M}), got {tuple(valid_positions.shape)}")
|
| 255 |
+
valid = valid_positions.to(device=local_ids.device)
|
| 256 |
+
else:
|
| 257 |
+
# Optional: support index pairs [K,2] where each row is (b_idx, m_idx) for valid positions
|
| 258 |
+
if valid_positions.dim() != 2 or valid_positions.size(1) != 2:
|
| 259 |
+
raise ValueError("valid_positions (indices) must be [K,2] with (batch_idx, col_idx)")
|
| 260 |
+
valid = torch.zeros((B, M), dtype=torch.bool, device=local_ids.device)
|
| 261 |
+
b_idx = valid_positions[:, 0].to(device=local_ids.device, dtype=torch.long)
|
| 262 |
+
m_idx = valid_positions[:, 1].to(device=local_ids.device, dtype=torch.long)
|
| 263 |
+
valid[b_idx, m_idx] = True
|
| 264 |
+
|
| 265 |
+
# Token mask: 1=valid, 0=invalid
|
| 266 |
+
token_mask = valid.to(dtype=torch.long) # [B,M]
|
| 267 |
+
|
| 268 |
+
# This is WRONG: we should allow __MASK__ to attend other columns
|
| 269 |
+
# # Invalid tokens must not contribute
|
| 270 |
+
# invalid = ~valid
|
| 271 |
+
# if invalid.any():
|
| 272 |
+
# tokens = tokens.masked_fill(invalid.unsqueeze(-1), 0.0)
|
| 273 |
+
|
| 274 |
+
return tokens, token_mask
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
# ============================================================
|
| 278 |
+
# DEMO
|
| 279 |
+
# ============================================================
|
| 280 |
+
|
| 281 |
+
def _demo_main():
|
| 282 |
+
import argparse
|
| 283 |
+
|
| 284 |
+
parser = argparse.ArgumentParser()
|
| 285 |
+
parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
|
| 286 |
+
parser.add_argument("--cat_vocab_json", type=str, default="data/cat_vocab.json")
|
| 287 |
+
parser.add_argument("--hidden_size", type=int, default=768)
|
| 288 |
+
parser.add_argument("--batch_size", type=int, default=4)
|
| 289 |
+
args = parser.parse_args()
|
| 290 |
+
|
| 291 |
+
tabular_meta = load_json(args.tabular_meta)
|
| 292 |
+
|
| 293 |
+
cat_names = get_categorical_feature_names_from_meta(tabular_meta)
|
| 294 |
+
print(f"Found {len(cat_names)} categorical columns")
|
| 295 |
+
|
| 296 |
+
vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names)
|
| 297 |
+
save_cat_vocab_json(vocab, args.cat_vocab_json)
|
| 298 |
+
print(f"Saved vocab to {args.cat_vocab_json}")
|
| 299 |
+
|
| 300 |
+
model = CategoricalEmbedding(
|
| 301 |
+
hidden_size=args.hidden_size,
|
| 302 |
+
cat_vocab_json=args.cat_vocab_json,
|
| 303 |
+
)
|
| 304 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 305 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 306 |
+
print(f"Total parameters (CategoricalEmbedding): {total_params:,} (trainable: {trainable_params:,})")
|
| 307 |
+
|
| 308 |
+
B = args.batch_size
|
| 309 |
+
M = len(cat_names)
|
| 310 |
+
|
| 311 |
+
local_ids = torch.zeros((B, M), dtype=torch.long)
|
| 312 |
+
|
| 313 |
+
with torch.no_grad():
|
| 314 |
+
out, mask = model(local_ids)
|
| 315 |
+
|
| 316 |
+
print("local_ids:", tuple(local_ids.shape))
|
| 317 |
+
print("output:", tuple(out.shape)) # [B,M,H]
|
| 318 |
+
print("mask:", tuple(mask.shape)) # [B,M]
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
if __name__ == "__main__":
|
| 322 |
+
_demo_main()
|
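As a quick illustration of the two `valid_positions` formats accepted by `CategoricalEmbedding.forward` above, here is a minimal sketch (sizes are made up) showing that the `[K,2]` index-pair form is just the `nonzero()` of the boolean form, and that the module's scatter logic round-trips it back:

```python
import torch

B, M = 4, 6  # illustrative batch size and number of categorical columns

# Boolean form: [B, M], True = observed
valid_bool = torch.ones((B, M), dtype=torch.bool)
valid_bool[0, 2] = False  # column 2 of sample 0 is missing

# Index-pair form: [K, 2], rows of (batch_idx, col_idx) for the VALID cells
valid_pairs = valid_bool.nonzero()

# Rebuild the boolean mask the same way the module does internally
rebuilt = torch.zeros((B, M), dtype=torch.bool)
rebuilt[valid_pairs[:, 0], valid_pairs[:, 1]] = True
assert torch.equal(rebuilt, valid_bool)
```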
modelling/embed_numeric.py
ADDED
|
@@ -0,0 +1,547 @@
|
| 1 |
+
# embed_numeric.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
"""
|
| 5 |
+
Numeric embedding module for tabular transformer.
|
| 6 |
+
|
| 7 |
+
Updates in this version:
|
| 8 |
+
- numeric_vocab.json now includes:
|
| 9 |
+
- total_numeric_tokens
|
| 10 |
+
- group_token_offsets (by n_in)
|
| 11 |
+
- _demo_main prints the total parameter count
|
| 12 |
+
|
| 13 |
+
Design:
|
| 14 |
+
- scalar numeric (n_in=1): 1 token
|
| 15 |
+
- vector numeric (n_in=L): L tokens
|
| 16 |
+
- per bucket (same n_in): GroupedMLP with per-column weights (no for-loop over columns)
|
| 17 |
+
input : [B, V, n_in]
|
| 18 |
+
output : [B, V*n_in, H]
|
| 19 |
+
- middle_size:
|
| 20 |
+
- None: 1-layer
|
| 21 |
+
- int : 2-layer (Linear -> GELU -> Linear)
|
| 22 |
+
- NumericIdEmbedding:
|
| 23 |
+
- per numeric column id embedding [H]
|
| 24 |
+
- broadcast across that column's n_in tokens
|
| 25 |
+
"""
|
| 26 |
+
|
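GroupedMLP is imported from modelling/utils.py and is not part of this hunk, so the stand-in below is only a rough mental model of the one-layer case: a per-column linear applied to all columns in parallel via a single einsum, which is what makes the "no for-loop over columns" claim above possible. The class name, init, and bias handling are illustrative, not the shipped implementation:

```python
import torch
import torch.nn as nn

class GroupedLinearSketch(nn.Module):
    """Illustrative per-column linear: x [B, n_var, n_in] -> y [B, n_var, n_out]."""

    def __init__(self, n_var: int, n_in: int, n_out: int):
        super().__init__()
        # one independent weight matrix (and bias) per column
        self.weight = nn.Parameter(torch.randn(n_var, n_in, n_out) * 0.02)
        self.bias = nn.Parameter(torch.zeros(n_var, n_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # batched matmul over the column axis, no Python loop
        return torch.einsum("bvi,vio->bvo", x, self.weight) + self.bias

y = GroupedLinearSketch(n_var=5, n_in=3, n_out=3 * 768)(torch.randn(2, 5, 3))
print(y.shape)  # torch.Size([2, 5, 2304]) -> reshaped downstream to [2, 5*3, 768] tokens
```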
| 27 |
+
from dataclasses import dataclass
|
| 28 |
+
from typing import Dict, List, Optional, Tuple
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
|
| 33 |
+
from utils import load_json, save_json, GroupedMLP
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ============================================================
|
| 37 |
+
# Meta parsing
|
| 38 |
+
# ============================================================
|
| 39 |
+
|
| 40 |
+
def infer_n_in_from_meta_item(info: Dict) -> int:
|
| 41 |
+
return int(info["array_length"]) if info["is_array_valued"] else 1
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def get_numeric_feature_names_and_dims_from_meta(tabular_meta: Dict) -> List[Tuple[str, int]]:
|
| 45 |
+
"""
|
| 46 |
+
Return list of (feature_name, n_in) for numeric features.
|
| 47 |
+
|
| 48 |
+
Heuristic:
|
| 49 |
+
- info['dataclass'] == 'numeric' is treated as numeric.
|
| 50 |
+
"""
|
| 51 |
+
out: List[Tuple[str, int]] = []
|
| 52 |
+
for name, info in tabular_meta.items():
|
| 53 |
+
if info.get("dataclass") != "numeric":
|
| 54 |
+
continue
|
| 55 |
+
n_in = infer_n_in_from_meta_item(info)
|
| 56 |
+
out.append((name, n_in))
|
| 57 |
+
# deterministic: group by n_in then name
|
| 58 |
+
out.sort(key=lambda x: (x[1], x[0]))
|
| 59 |
+
return out
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ============================================================
|
| 63 |
+
# Vocab/spec building
|
| 64 |
+
# ============================================================
|
| 65 |
+
|
| 66 |
+
@dataclass
|
| 67 |
+
class NumColSpec:
|
| 68 |
+
name: str
|
| 69 |
+
col_id: int
|
| 70 |
+
n_in: int
|
| 71 |
+
group_index: int
|
| 72 |
+
index_within_group: int
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def build_numeric_vocab_spec_from_meta(tabular_meta: Dict) -> Dict:
|
| 76 |
+
"""
|
| 77 |
+
Build numeric_vocab.json dict.
|
| 78 |
+
|
| 79 |
+
Output keys:
|
| 80 |
+
- ordered_feature_names
|
| 81 |
+
- features[name] = {col_id, n_in, group_index, index_within_group}
|
| 82 |
+
- groups = [{n_in, feature_names}, ...] sorted by n_in asc
|
| 83 |
+
- total_numeric_tokens
|
| 84 |
+
- group_token_offsets: { "<n_in>": <start_token_index> }
|
| 85 |
+
token order is groups by n_in asc, within group by feature name
|
| 86 |
+
"""
|
| 87 |
+
feats = get_numeric_feature_names_and_dims_from_meta(tabular_meta)
|
| 88 |
+
if not feats:
|
| 89 |
+
raise ValueError("No numeric features found (dataclass=='numeric').")
|
| 90 |
+
|
| 91 |
+
# group by n_in
|
| 92 |
+
groups_map: Dict[int, List[str]] = {}
|
| 93 |
+
for name, n_in in feats:
|
| 94 |
+
groups_map.setdefault(n_in, []).append(name)
|
| 95 |
+
|
| 96 |
+
for n_in in groups_map:
|
| 97 |
+
groups_map[n_in] = sorted(groups_map[n_in])
|
| 98 |
+
|
| 99 |
+
group_nins = sorted(groups_map.keys())
|
| 100 |
+
|
| 101 |
+
groups: List[Dict] = []
|
| 102 |
+
ordered_feature_names: List[str] = []
|
| 103 |
+
|
| 104 |
+
for n_in in group_nins:
|
| 105 |
+
names = groups_map[n_in]
|
| 106 |
+
groups.append({"n_in": int(n_in), "feature_names": names})
|
| 107 |
+
ordered_feature_names.extend(names)
|
| 108 |
+
|
| 109 |
+
# build per-feature mapping
|
| 110 |
+
name_to_group: Dict[str, Tuple[int, int]] = {}
|
| 111 |
+
for gi, g in enumerate(groups):
|
| 112 |
+
for idx, nm in enumerate(g["feature_names"]):
|
| 113 |
+
name_to_group[nm] = (gi, idx)
|
| 114 |
+
|
| 115 |
+
features: Dict[str, Dict] = {}
|
| 116 |
+
for col_id, nm in enumerate(ordered_feature_names):
|
| 117 |
+
gi, idx = name_to_group[nm]
|
| 118 |
+
n_in = int(groups[gi]["n_in"])
|
| 119 |
+
features[nm] = {
|
| 120 |
+
"col_id": int(col_id),
|
| 121 |
+
"n_in": int(n_in),
|
| 122 |
+
"group_index": int(gi),
|
| 123 |
+
"index_within_group": int(idx),
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
# total tokens + group token offsets
|
| 127 |
+
total_numeric_tokens = 0
|
| 128 |
+
group_token_offsets: Dict[str, int] = {}
|
| 129 |
+
running = 0
|
| 130 |
+
for g in groups:
|
| 131 |
+
n_in = int(g["n_in"])
|
| 132 |
+
group_token_offsets[str(n_in)] = int(running)
|
| 133 |
+
V = len(g["feature_names"])
|
| 134 |
+
running += V * n_in
|
| 135 |
+
total_numeric_tokens += V * n_in
|
| 136 |
+
|
| 137 |
+
spec = {
|
| 138 |
+
"ordered_feature_names": ordered_feature_names,
|
| 139 |
+
"features": features,
|
| 140 |
+
"groups": groups,
|
| 141 |
+
"total_numeric_tokens": int(total_numeric_tokens),
|
| 142 |
+
"group_token_offsets": group_token_offsets, # keys are strings to be JSON-friendly
|
| 143 |
+
}
|
| 144 |
+
return spec
|
| 145 |
+
|
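To make the bookkeeping concrete, here is the builder applied to a tiny hand-written meta dict (the column names are invented for illustration):

```python
toy_meta = {
    "ph":        {"dataclass": "numeric", "is_array_valued": False},
    "clay_pct":  {"dataclass": "numeric", "is_array_valued": False},
    "spectrum":  {"dataclass": "numeric", "is_array_valued": True, "array_length": 3},
    "soil_type": {"dataclass": "categorical"},  # skipped: not numeric
}
spec = build_numeric_vocab_spec_from_meta(toy_meta)

# groups: n_in=1 -> ["clay_pct", "ph"]; n_in=3 -> ["spectrum"]
# total_numeric_tokens = 2*1 + 1*3 = 5
# group_token_offsets  = {"1": 0, "3": 2}
print(spec["total_numeric_tokens"], spec["group_token_offsets"])
```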
| 146 |
+
|
| 147 |
+
# ============================================================
|
| 148 |
+
# Core modules
|
| 149 |
+
# ============================================================
|
| 150 |
+
|
| 151 |
+
class NumericIdEmbedding(nn.Module):
|
| 152 |
+
"""
|
| 153 |
+
Per-numeric-column ID embedding in the GLOBAL numeric namespace.
|
| 154 |
+
Broadcast each global column id vector across its n_in tokens.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def __init__(self, num_numeric_cols: int, hidden_size: int):
|
| 158 |
+
super().__init__()
|
| 159 |
+
self.num_numeric_cols = int(num_numeric_cols)
|
| 160 |
+
self.hidden_size = int(hidden_size)
|
| 161 |
+
self.emb = nn.Embedding(self.num_numeric_cols, self.hidden_size)
|
| 162 |
+
|
| 163 |
+
def forward(self, global_col_ids: torch.LongTensor, batch_size: int, n_in: int) -> torch.Tensor:
|
| 164 |
+
"""
|
| 165 |
+
global_col_ids: [V] in global numeric namespace
|
| 166 |
+
returns: [B, V*n_in, H]
|
| 167 |
+
"""
|
| 168 |
+
if global_col_ids.dim() != 1:
|
| 169 |
+
raise ValueError(f"global_col_ids must be [V], got {tuple(global_col_ids.shape)}")
|
| 170 |
+
|
| 171 |
+
V = global_col_ids.numel()
|
| 172 |
+
n_in = int(n_in)
|
| 173 |
+
|
| 174 |
+
id_vec = self.emb(global_col_ids) # [V, H]
|
| 175 |
+
id_vec = id_vec.view(1, V, 1, self.hidden_size).expand(batch_size, V, n_in, self.hidden_size)
|
| 176 |
+
return id_vec.reshape(batch_size, V * n_in, self.hidden_size)
|
| 177 |
+
|
| 178 |
+
def init_weights(self, std: float = 0.02):
|
| 179 |
+
nn.init.normal_(self.emb.weight, std=std)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class NumericMaskEmbedding(nn.Module):
|
| 183 |
+
"""
|
| 184 |
+
Per-bucket numeric mask embedding.
|
| 185 |
+
Local to one (n_in) group / bucket.
|
| 186 |
+
|
| 187 |
+
Parameter shape:
|
| 188 |
+
[num_bucket_cols, n_in, H]
|
| 189 |
+
|
| 190 |
+
So missing numeric columns are represented by:
|
| 191 |
+
(bucket-local column index, sub-token index)
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
def __init__(self, num_bucket_cols: int, n_in: int, hidden_size: int):
|
| 195 |
+
super().__init__()
|
| 196 |
+
self.num_bucket_cols = int(num_bucket_cols)
|
| 197 |
+
self.n_in = int(n_in)
|
| 198 |
+
self.hidden_size = int(hidden_size)
|
| 199 |
+
|
| 200 |
+
self.emb = nn.Parameter(
|
| 201 |
+
torch.empty(self.num_bucket_cols, self.n_in, self.hidden_size)
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
def forward(self, local_col_ids: torch.LongTensor, batch_size: int) -> torch.Tensor:
|
| 205 |
+
"""
|
| 206 |
+
local_col_ids: [V] bucket-local ids, usually 0 to V-1
|
| 207 |
+
returns: [B, V*n_in, H]
|
| 208 |
+
"""
|
| 209 |
+
if local_col_ids.dim() != 1:
|
| 210 |
+
raise ValueError(f"local_col_ids must be [V], got {tuple(local_col_ids.shape)}")
|
| 211 |
+
|
| 212 |
+
V = local_col_ids.numel()
|
| 213 |
+
mask_vec = self.emb[local_col_ids] # [V, n_in, H]
|
| 214 |
+
mask_vec = mask_vec.unsqueeze(0).expand(batch_size, V, self.n_in, self.hidden_size)
|
| 215 |
+
return mask_vec.reshape(batch_size, V * self.n_in, self.hidden_size)
|
| 216 |
+
|
| 217 |
+
def init_weights(self, std: float = 0.02):
|
| 218 |
+
nn.init.normal_(self.emb, std=std)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
class NumericEmbedding(nn.Module):
|
| 222 |
+
"""
|
| 223 |
+
Full numeric embedding for all numeric columns described by numeric_vocab.json.
|
| 224 |
+
|
| 225 |
+
Forward expects bucketed input:
|
| 226 |
+
values_by_nin: { n_in: x[B, V, n_in] }
|
| 227 |
+
where V must match the feature count and order of that n_in group.
|
| 228 |
+
|
| 229 |
+
Output token ordering:
|
| 230 |
+
groups by n_in ascending (as stored in spec["groups"]),
|
| 231 |
+
within each group by feature_names order.
|
| 232 |
+
"""
|
| 233 |
+
|
| 234 |
+
def __init__(self, hidden_size: int, numeric_vocab_json: str, middle_size: Optional[int] = None):
|
| 235 |
+
super().__init__()
|
| 236 |
+
self.hidden_size = int(hidden_size)
|
| 237 |
+
self.middle_size = None if middle_size is None else int(middle_size)
|
| 238 |
+
|
| 239 |
+
spec = load_json(numeric_vocab_json)
|
| 240 |
+
self.ordered_feature_names: List[str] = list(spec["ordered_feature_names"])
|
| 241 |
+
self.features: Dict[str, Dict] = dict(spec["features"])
|
| 242 |
+
self.groups: List[Dict] = list(spec["groups"])
|
| 243 |
+
self.total_numeric_tokens = int(spec.get("total_numeric_tokens", -1))
|
| 244 |
+
|
| 245 |
+
num_cols = len(self.ordered_feature_names)
|
| 246 |
+
|
| 247 |
+
# Global numeric namespace id embedding
|
| 248 |
+
self.id_emb = NumericIdEmbedding(
|
| 249 |
+
num_numeric_cols=num_cols,
|
| 250 |
+
hidden_size=self.hidden_size,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# Per-group mask embedding
|
| 254 |
+
self.mask_emb = nn.ModuleDict()
|
| 255 |
+
|
| 256 |
+
# Per-group value embedding
|
| 257 |
+
self.group_mlps = nn.ModuleList()
|
| 258 |
+
|
| 259 |
+
self.group_nins: List[int] = []
|
| 260 |
+
self._num_groups = len(self.groups)
|
| 261 |
+
|
| 262 |
+
# Optional: useful for debugging / downstream checks
|
| 263 |
+
self.group_sizes: List[int] = []
|
| 264 |
+
|
| 265 |
+
# Build one block per group
|
| 266 |
+
for gi, g in enumerate(self.groups):
|
| 267 |
+
n_in = int(g["n_in"])
|
| 268 |
+
names = list(g["feature_names"])
|
| 269 |
+
V = len(names)
|
| 270 |
+
|
| 271 |
+
self.group_nins.append(n_in)
|
| 272 |
+
self.group_sizes.append(V)
|
| 273 |
+
|
| 274 |
+
# ---- spec consistency check
|
| 275 |
+
# group_index and index_within_group in features must match groups[gi]["feature_names"] order
|
| 276 |
+
local_ids = []
|
| 277 |
+
for local_idx, nm in enumerate(names):
|
| 278 |
+
f = self.features[nm]
|
| 279 |
+
|
| 280 |
+
if int(f["group_index"]) != gi:
|
| 281 |
+
raise ValueError(
|
| 282 |
+
f"Feature {nm} has group_index={f['group_index']}, expected {gi}"
|
| 283 |
+
)
|
| 284 |
+
if int(f["n_in"]) != n_in:
|
| 285 |
+
raise ValueError(
|
| 286 |
+
f"Feature {nm} has n_in={f['n_in']}, expected {n_in}"
|
| 287 |
+
)
|
| 288 |
+
if int(f["index_within_group"]) != local_idx:
|
| 289 |
+
raise ValueError(
|
| 290 |
+
f"Feature {nm} has index_within_group={f['index_within_group']}, expected {local_idx}"
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
local_ids.append(int(f["index_within_group"]))
|
| 294 |
+
|
| 295 |
+
# strict check: local ids must be exactly 0 to V-1 with no gap / no duplicate
|
| 296 |
+
if sorted(local_ids) != list(range(V)):
|
| 297 |
+
raise ValueError(
|
| 298 |
+
f"Group gi={gi}, n_in={n_in} has invalid index_within_group set: "
|
| 299 |
+
f"got {sorted(local_ids)}, expected {list(range(V))}"
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# ---- observed value path: bucket-local ordering
|
| 303 |
+
self.group_mlps.append(
|
| 304 |
+
GroupedMLP(
|
| 305 |
+
n_var=V,
|
| 306 |
+
n_in=n_in,
|
| 307 |
+
n_out=n_in * self.hidden_size,
|
| 308 |
+
middle_size=self.middle_size,
|
| 309 |
+
)
|
| 310 |
+
)
|
| 311 |
+
|
| 312 |
+
# ---- global ids for NumericIdEmbedding
|
| 313 |
+
global_col_ids = [int(self.features[nm]["col_id"]) for nm in names]
|
| 314 |
+
self.register_buffer(
|
| 315 |
+
f"group_global_col_ids_{gi}",
|
| 316 |
+
torch.tensor(global_col_ids, dtype=torch.long),
|
| 317 |
+
persistent=True,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
# ---- local ids for NumericMaskEmbedding
|
| 321 |
+
local_col_ids = [int(self.features[nm]["index_within_group"]) for nm in names]
|
| 322 |
+
self.register_buffer(
|
| 323 |
+
f"group_local_col_ids_{gi}",
|
| 324 |
+
torch.tensor(local_col_ids, dtype=torch.long),
|
| 325 |
+
persistent=True,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
# one mask embedding per bucket
|
| 329 |
+
self.mask_emb[str(n_in)] = NumericMaskEmbedding(
|
| 330 |
+
num_bucket_cols=V,
|
| 331 |
+
n_in=n_in,
|
| 332 |
+
hidden_size=self.hidden_size,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
if self.total_numeric_tokens < 0:
|
| 336 |
+
self.total_numeric_tokens = sum(
|
| 337 |
+
len(g["feature_names"]) * int(g["n_in"]) for g in self.groups
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
+
def init_weights(self, std: float = 0.02):
|
| 341 |
+
self.id_emb.init_weights(std=std)
|
| 342 |
+
|
| 343 |
+
for _, mask_mod in self.mask_emb.items():
|
| 344 |
+
mask_mod.init_weights(std=std)
|
| 345 |
+
|
| 346 |
+
for mlp in self.group_mlps:
|
| 347 |
+
mlp.init_weights(std=std)
|
| 348 |
+
|
| 349 |
+
def forward(
|
| 350 |
+
self,
|
| 351 |
+
values_by_nin: Dict[int, torch.Tensor],
|
| 352 |
+
valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None,
|
| 353 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 354 |
+
"""
|
| 355 |
+
Args:
|
| 356 |
+
values_by_nin:
|
| 357 |
+
{ n_in: x } where x is [B, V, n_in]
|
| 358 |
+
Missing numeric values are assumed to have been filled in x already.
|
| 359 |
+
|
| 360 |
+
valid_positions_by_nin (optional):
|
| 361 |
+
{ n_in: valid_cols } where valid_cols is BoolTensor [B, V]
|
| 362 |
+
True means this COLUMN is observed/valid.
|
| 363 |
+
|
| 364 |
+
Note:
|
| 365 |
+
This is COLUMN-level mask, not token-level.
|
| 366 |
+
It is expanded to token-level by repeating across n_in.
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
tokens: [B, total_numeric_tokens, H]
|
| 370 |
+
token_mask: [B, total_numeric_tokens] (1=valid, 0=missing)
|
| 371 |
+
"""
|
| 372 |
+
outs = []
|
| 373 |
+
masks = []
|
| 374 |
+
batch_size = None
|
| 375 |
+
|
| 376 |
+
for gi, n_in in enumerate(self.group_nins):
|
| 377 |
+
if n_in not in values_by_nin:
|
| 378 |
+
raise KeyError(f"Missing bucket input for n_in={n_in}")
|
| 379 |
+
|
| 380 |
+
x = values_by_nin[n_in] # [B, V, n_in]
|
| 381 |
+
if x.dim() != 3 or x.size(-1) != n_in:
|
| 382 |
+
raise ValueError(f"Bucket n_in={n_in} expects x [B,V,{n_in}], got {tuple(x.shape)}")
|
| 383 |
+
|
| 384 |
+
if batch_size is None:
|
| 385 |
+
batch_size = x.size(0)
|
| 386 |
+
elif x.size(0) != batch_size:
|
| 387 |
+
raise ValueError("All buckets must share the same batch size")
|
| 388 |
+
|
| 389 |
+
B, V, _ = x.shape
|
| 390 |
+
|
| 391 |
+
expected_V = self.group_sizes[gi]
|
| 392 |
+
if V != expected_V:
|
| 393 |
+
raise ValueError(
|
| 394 |
+
f"Bucket n_in={n_in} expects V={expected_V}, got V={V}"
|
| 395 |
+
)
|
| 396 |
+
|
| 397 |
+
# column-level valid mask [B, V]
|
| 398 |
+
if valid_positions_by_nin is None:
|
| 399 |
+
valid_cols = torch.ones((B, V), dtype=torch.bool, device=x.device)
|
| 400 |
+
else:
|
| 401 |
+
if n_in not in valid_positions_by_nin:
|
| 402 |
+
raise KeyError(f"Missing valid mask for bucket n_in={n_in}")
|
| 403 |
+
|
| 404 |
+
valid_cols = valid_positions_by_nin[n_in]
|
| 405 |
+
if valid_cols.dtype != torch.bool:
|
| 406 |
+
raise ValueError(
|
| 407 |
+
f"valid_positions_by_nin[{n_in}] must be bool tensor, got {valid_cols.dtype}"
|
| 408 |
+
)
|
| 409 |
+
if valid_cols.shape != (B, V):
|
| 410 |
+
raise ValueError(
|
| 411 |
+
f"valid_positions_by_nin[{n_in}] must be [B,V]=[{B},{V}], got {tuple(valid_cols.shape)}"
|
| 412 |
+
)
|
| 413 |
+
valid_cols = valid_cols.to(device=x.device)
|
| 414 |
+
|
| 415 |
+
# ---- observed numeric value embedding
|
| 416 |
+
mlp = self.group_mlps[gi]
|
| 417 |
+
param = next(mlp.parameters())
|
| 418 |
+
x = x.to(device=param.device, dtype=param.dtype)
|
| 419 |
+
|
| 420 |
+
# [B, V, n_in] -> [B, V, n_in*H]
|
| 421 |
+
y = mlp(x)
|
| 422 |
+
|
| 423 |
+
# [B, V, n_in*H] -> [B, V*n_in, H]
|
| 424 |
+
y_tok = y.view(B, V, n_in, self.hidden_size).reshape(B, V * n_in, self.hidden_size)
|
| 425 |
+
|
| 426 |
+
# [B, V] -> [B, V*n_in]
|
| 427 |
+
valid_tok = valid_cols.unsqueeze(-1).expand(B, V, n_in).reshape(B, V * n_in)
|
| 428 |
+
|
| 429 |
+
# ---- missing replacement: bucket-local mask embedding
|
| 430 |
+
local_col_ids = getattr(self, f"group_local_col_ids_{gi}") # [V]
|
| 431 |
+
mask_tok = self.mask_emb[str(n_in)](local_col_ids, batch_size=B)
|
| 432 |
+
|
| 433 |
+
if (~valid_tok).any():
|
| 434 |
+
y_tok = torch.where(
|
| 435 |
+
valid_tok.unsqueeze(-1),
|
| 436 |
+
y_tok,
|
| 437 |
+
mask_tok,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
# ---- add global numeric column id embedding
|
| 441 |
+
global_col_ids = getattr(self, f"group_global_col_ids_{gi}") # [V]
|
| 442 |
+
y_tok = y_tok + self.id_emb(global_col_ids, batch_size=B, n_in=n_in)
|
| 443 |
+
|
| 444 |
+
token_mask = valid_tok.to(dtype=torch.long)
|
| 445 |
+
|
| 446 |
+
outs.append(y_tok)
|
| 447 |
+
masks.append(token_mask)
|
| 448 |
+
|
| 449 |
+
tokens = torch.cat(outs, dim=1)
|
| 450 |
+
token_mask = torch.cat(masks, dim=1)
|
| 451 |
+
|
| 452 |
+
if token_mask.shape[:2] != tokens.shape[:2]:
|
| 453 |
+
raise RuntimeError("token_mask shape mismatch with tokens")
|
| 454 |
+
|
| 455 |
+
return tokens, token_mask
|
| 456 |
+
|
| 457 |
+
|
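Before the demo below, it may help to see how the bucketed input could be assembled from ordinary per-column tensors. This is a minimal sketch, assuming each column arrives as a float tensor of shape [B] or [B, n_in] keyed by feature name; the helper and its name are illustrative, not part of the repo:

```python
import torch

def build_buckets(spec, columns, batch_size):
    """columns: {name: FloatTensor [B] or [B, n_in]} -> {n_in: [B, V, n_in]} in spec order."""
    values_by_nin = {}
    for g in spec["groups"]:
        n_in = int(g["n_in"])
        # the feature order inside each group defines the V axis
        stacked = [columns[name].reshape(batch_size, n_in) for name in g["feature_names"]]
        values_by_nin[n_in] = torch.stack(stacked, dim=1)  # [B, V, n_in]
    return values_by_nin
```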
| 458 |
+
# ============================================================
|
| 459 |
+
# DEMO
|
| 460 |
+
# ============================================================
|
| 461 |
+
|
| 462 |
+
def _demo_main():
|
| 463 |
+
import argparse
|
| 464 |
+
|
| 465 |
+
parser = argparse.ArgumentParser()
|
| 466 |
+
parser.add_argument("--tabular_meta", type=str, default="data/tabular_meta.json")
|
| 467 |
+
parser.add_argument("--numeric_vocab_json", type=str, default="data/numeric_vocab.json")
|
| 468 |
+
parser.add_argument("--hidden_size", type=int, default=768)
|
| 469 |
+
parser.add_argument("--middle_size", type=int, default=-1,
|
| 470 |
+
help="If <0 -> one-layer. If >=0 -> two-layer with this middle size.")
|
| 471 |
+
parser.add_argument("--batch_size", type=int, default=4)
|
| 472 |
+
parser.add_argument("--device", type=str, default=None)
|
| 473 |
+
parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
|
| 474 |
+
args = parser.parse_args()
|
| 475 |
+
|
| 476 |
+
device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 477 |
+
dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
|
| 478 |
+
dtype = dtype_map[args.dtype]
|
| 479 |
+
|
| 480 |
+
meta = load_json(args.tabular_meta)
|
| 481 |
+
|
| 482 |
+
spec = build_numeric_vocab_spec_from_meta(meta)
|
| 483 |
+
save_json(spec, args.numeric_vocab_json)
|
| 484 |
+
print(f"Saved numeric vocab spec to: {args.numeric_vocab_json}")
|
| 485 |
+
print(f"Groups (n_in -> V):", {g["n_in"]: len(g["feature_names"]) for g in spec["groups"]})
|
| 486 |
+
print("total_numeric_tokens:", spec["total_numeric_tokens"])
|
| 487 |
+
print("group_token_offsets:", spec["group_token_offsets"])
|
| 488 |
+
|
| 489 |
+
middle_size = None if args.middle_size < 0 else int(args.middle_size)
|
| 490 |
+
model = NumericEmbedding(
|
| 491 |
+
hidden_size=args.hidden_size,
|
| 492 |
+
numeric_vocab_json=args.numeric_vocab_json,
|
| 493 |
+
middle_size=middle_size,
|
| 494 |
+
).to(device=device, dtype=dtype)
|
| 495 |
+
model.init_weights()
|
| 496 |
+
model.eval()
|
| 497 |
+
|
| 498 |
+
total_params = sum(p.numel() for p in model.parameters())
|
| 499 |
+
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 500 |
+
print(f"Total parameters (NumericEmbedding): {total_params:,} (trainable: {trainable_params:,})")
|
| 501 |
+
|
| 502 |
+
# create demo inputs bucketed by n_in
|
| 503 |
+
B = args.batch_size
|
| 504 |
+
values_by_nin: Dict[int, torch.Tensor] = {}
|
| 505 |
+
valid_positions_by_nin: Dict[int, torch.Tensor] = {}
|
| 506 |
+
|
| 507 |
+
for g in spec["groups"]:
|
| 508 |
+
n_in = int(g["n_in"])
|
| 509 |
+
V = len(g["feature_names"])
|
| 510 |
+
|
| 511 |
+
# random numeric inputs
|
| 512 |
+
x = torch.randn(B, V, n_in, device=device, dtype=dtype)
|
| 513 |
+
values_by_nin[n_in] = x
|
| 514 |
+
|
| 515 |
+
# Build valid mask (column-level)
|
| 516 |
+
# shape: [B, V], True = valid
|
| 517 |
+
valid_cols = torch.ones((B, V), dtype=torch.bool, device=device)
|
| 518 |
+
|
| 519 |
+
# Mark first sample's first 2 columns as invalid
|
| 520 |
+
num_to_invalidate = min(2, V)
|
| 521 |
+
valid_cols[0, :num_to_invalidate] = False
|
| 522 |
+
|
| 523 |
+
valid_positions_by_nin[n_in] = valid_cols
|
| 524 |
+
|
| 525 |
+
with torch.no_grad():
|
| 526 |
+
out, mask = model(values_by_nin, valid_positions_by_nin)
|
| 527 |
+
|
| 528 |
+
print("Buckets:", {k: tuple(v.shape) for k, v in values_by_nin.items()})
|
| 529 |
+
print("Output tokens:", tuple(out.shape), out.dtype, out.device) # [B, total_numeric_tokens, H]
|
| 530 |
+
print("Masks:", tuple(mask.shape), mask.dtype, mask.device) # [B, total_numeric_tokens]
|
| 531 |
+
|
| 532 |
+
# ---- Inspect first sample
|
| 533 |
+
print("\nFirst sample mask (first 5 tokens):")
|
| 534 |
+
print(mask[0, :5])
|
| 535 |
+
|
| 536 |
+
print("\nFirst sample token L2 norms (first 5 tokens):")
|
| 537 |
+
print(out[0, :5].norm(dim=-1))
|
| 538 |
+
|
| 539 |
+
print("\nSecond sample mask (first 5 tokens):")
|
| 540 |
+
print(mask[1, :5])
|
| 541 |
+
|
| 542 |
+
print("\nSecond sample token L2 norms (first 5 tokens):")
|
| 543 |
+
print(out[1, :5].norm(dim=-1))
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
if __name__ == "__main__":
|
| 547 |
+
_demo_main()
|
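Since the token order is fixed (groups by n_in ascending, features by name within each group), the saved spec is enough to locate any numeric feature's tokens in the concatenated [B, total_numeric_tokens, H] output. The helper below is hypothetical but follows directly from how the offsets are built above:

```python
def numeric_token_range(spec, name):
    """Return (start, end) token indices of feature `name` in the concatenated output."""
    f = spec["features"][name]
    n_in = int(f["n_in"])
    start = int(spec["group_token_offsets"][str(n_in)]) + int(f["index_within_group"]) * n_in
    return start, start + n_in  # end is exclusive

# e.g. with the toy spec above: numeric_token_range(spec, "spectrum") == (2, 5)
```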
modelling/embed_vision_gemma3n.py
ADDED
|
@@ -0,0 +1,552 @@
|
| 1 |
+
# embed_vision_gemma3n.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from typing import Optional, Tuple, Dict
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from safetensors.torch import load_file as safetensors_load_file
|
| 10 |
+
from transformers import AutoConfig, AutoModel
|
| 11 |
+
from transformers.models.gemma3n.modeling_gemma3n import Gemma3nMultimodalEmbedder # noqa
|
| 12 |
+
|
| 13 |
+
from utils import load_json
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _split_state_dict_from_tmp(sd: Dict[str, torch.Tensor]) \
|
| 17 |
+
-> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
|
| 18 |
+
"""
|
| 19 |
+
The model extractor saved tmp.state_dict(), where tmp has attributes:
|
| 20 |
+
- vision_tower
|
| 21 |
+
- embed_vision (optional)
|
| 22 |
+
So keys look like:
|
| 23 |
+
- vision_tower.xxx
|
| 24 |
+
- embed_vision.xxx
|
| 25 |
+
"""
|
| 26 |
+
vt = {}
|
| 27 |
+
ev = {}
|
| 28 |
+
for k, v in sd.items():
|
| 29 |
+
if k.startswith("vision_tower."):
|
| 30 |
+
vt[k[len("vision_tower."):]] = v
|
| 31 |
+
elif k.startswith("embed_vision."):
|
| 32 |
+
ev[k[len("embed_vision."):]] = v
|
| 33 |
+
return vt, ev
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ============================================================
|
| 37 |
+
# Optional lightweight learnable token reducer
|
| 38 |
+
# ============================================================
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class VisionTokenReducer(nn.Module):
|
| 42 |
+
"""
|
| 43 |
+
Perceiver-style learnable cross-attention pooling with optional bottleneck.
|
| 44 |
+
|
| 45 |
+
Base (no bottleneck):
|
| 46 |
+
[B,T,D] -> [B,K,D]
|
| 47 |
+
|
| 48 |
+
Bottleneck mode (bottleneck_dim=d):
|
| 49 |
+
[B,T,D] -> down -> [B,T,d] -> cross-attn -> [B,K,d] -> (optional up) -> [B,K,D]
|
| 50 |
+
|
| 51 |
+
Notes:
|
| 52 |
+
- num_heads does NOT change parameter count of MultiheadAttention (depends on D only).
|
| 53 |
+
- perform_norm_latent controls whether to pre-norm the learnable latent queries.
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
vision_dim: int,
|
| 59 |
+
num_output_tokens: int,
|
| 60 |
+
num_heads: int = 4,
|
| 61 |
+
perform_norm_latent: bool = True,
|
| 62 |
+
bottleneck_dim: Optional[int] = None,
|
| 63 |
+
project_back: bool = True,
|
| 64 |
+
):
|
| 65 |
+
super().__init__()
|
| 66 |
+
|
| 67 |
+
self.vision_dim = int(vision_dim)
|
| 68 |
+
self.num_output_tokens = int(num_output_tokens)
|
| 69 |
+
self.num_heads = int(num_heads)
|
| 70 |
+
self.perform_norm_latent = bool(perform_norm_latent)
|
| 71 |
+
|
| 72 |
+
self.bottleneck_dim = None if bottleneck_dim is None else int(bottleneck_dim)
|
| 73 |
+
self.project_back = bool(project_back)
|
| 74 |
+
|
| 75 |
+
# Decide the attention working dimension: D (base) or d (bottleneck)
|
| 76 |
+
attn_dim = self.vision_dim if self.bottleneck_dim is None else self.bottleneck_dim
|
| 77 |
+
if attn_dim % self.num_heads != 0:
|
| 78 |
+
raise ValueError(f"embed_dim ({attn_dim}) must be divisible by num_heads ({self.num_heads})")
|
| 79 |
+
|
| 80 |
+
# Optional projection layers for bottleneck mode
|
| 81 |
+
if self.bottleneck_dim is None:
|
| 82 |
+
self.down = None
|
| 83 |
+
self.up = None
|
| 84 |
+
else:
|
| 85 |
+
# bias=False keeps it lightweight; switch to True if you prefer
|
| 86 |
+
self.down = nn.Linear(self.vision_dim, attn_dim, bias=False)
|
| 87 |
+
self.up = nn.Linear(attn_dim, self.vision_dim, bias=False) if self.project_back else None
|
| 88 |
+
|
| 89 |
+
# Learnable latent tokens (K, attn_dim)
|
| 90 |
+
self.latents = nn.Parameter(torch.randn(self.num_output_tokens, attn_dim) * 0.02)
|
| 91 |
+
|
| 92 |
+
# Separate norms: typically more stable than sharing one LN
|
| 93 |
+
self.norm_latents = nn.LayerNorm(attn_dim)
|
| 94 |
+
self.norm_x = nn.LayerNorm(attn_dim)
|
| 95 |
+
|
| 96 |
+
# Cross-attention: query=latents, key/value=x
|
| 97 |
+
self.attn = nn.MultiheadAttention(
|
| 98 |
+
embed_dim=attn_dim,
|
| 99 |
+
num_heads=self.num_heads,
|
| 100 |
+
batch_first=True,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
def init_weights(self, std: float = 0.02):
|
| 104 |
+
# Optional bottleneck projections
|
| 105 |
+
if self.down is not None:
|
| 106 |
+
nn.init.normal_(self.down.weight, std=std)
|
| 107 |
+
if self.up is not None:
|
| 108 |
+
nn.init.normal_(self.up.weight, std=std)
|
| 109 |
+
|
| 110 |
+
# Learnable latent queries
|
| 111 |
+
nn.init.normal_(self.latents, std=std)
|
| 112 |
+
|
| 113 |
+
# LayerNorm
|
| 114 |
+
nn.init.ones_(self.norm_latents.weight)
|
| 115 |
+
nn.init.zeros_(self.norm_latents.bias)
|
| 116 |
+
nn.init.ones_(self.norm_x.weight)
|
| 117 |
+
nn.init.zeros_(self.norm_x.bias)
|
| 118 |
+
|
| 119 |
+
# MultiheadAttention: use PyTorch's own reset only
|
| 120 |
+
self.attn._reset_parameters() # noqa
|
| 121 |
+
|
| 122 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 123 |
+
"""
|
| 124 |
+
Args:
|
| 125 |
+
x: [B, T, D] where D == vision_dim
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
out: [B, K, D] if (bottleneck_dim is None) or project_back=True
|
| 129 |
+
[B, K, d] if bottleneck_dim is not None and project_back=False
|
| 130 |
+
"""
|
| 131 |
+
if x.dim() != 3:
|
| 132 |
+
raise ValueError(f"Expected x [B,T,D], got {tuple(x.shape)}")
|
| 133 |
+
if x.size(-1) != self.vision_dim:
|
| 134 |
+
raise ValueError(f"Expected last dim D={self.vision_dim}, got {x.size(-1)}")
|
| 135 |
+
|
| 136 |
+
B = x.size(0)
|
| 137 |
+
|
| 138 |
+
# Bottleneck projection if enabled
|
| 139 |
+
if self.down is not None:
|
| 140 |
+
x = self.down(x) # [B,T,d]
|
| 141 |
+
|
| 142 |
+
# Expand learnable latents across batch
|
| 143 |
+
latents = self.latents.unsqueeze(0).expand(B, -1, -1) # [B,K,attn_dim]
|
| 144 |
+
|
| 145 |
+
# Pre-norm (optional for latents, always for input tokens)
|
| 146 |
+
if self.perform_norm_latent:
|
| 147 |
+
latents = self.norm_latents(latents)
|
| 148 |
+
x = self.norm_x(x)
|
| 149 |
+
|
| 150 |
+
# Cross-attention pooling
|
| 151 |
+
out, _ = self.attn(query=latents, key=x, value=x) # [B,K,attn_dim]
|
| 152 |
+
|
| 153 |
+
# Project back to original dim if requested
|
| 154 |
+
if self.up is not None:
|
| 155 |
+
out = self.up(out) # [B,K,D]
|
| 156 |
+
|
| 157 |
+
return out
|
| 158 |
+
|
| 159 |
+
|
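A quick shape walk-through of the reducer in bottleneck mode; the widths are illustrative (2048 stands in for the vision width, 768 for the bottleneck):

```python
import torch

reducer = VisionTokenReducer(
    vision_dim=2048,       # D: width of incoming vision tokens
    num_output_tokens=32,  # K
    num_heads=4,
    bottleneck_dim=768,    # d: cross-attention runs at this width
    project_back=True,     # map [B, K, d] back up to [B, K, D]
)
reducer.init_weights()

x = torch.randn(2, 256, 2048)  # [B, T, D]
print(reducer(x).shape)        # torch.Size([2, 32, 2048]); [2, 32, 768] if project_back=False
```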
| 160 |
+
# ============================================================
|
| 161 |
+
# Main body
|
| 162 |
+
# ============================================================
|
| 163 |
+
|
| 164 |
+
class Gemma3nVisionFeatureExtractor(nn.Module):
|
| 165 |
+
"""
|
| 166 |
+
Vision-only feature extractor for Gemma-3n that matches transformers' Gemma3nModel.get_image_features().
|
| 167 |
+
|
| 168 |
+
Input: pixel_values [B, 3, H, W]
|
| 169 |
+
Output: image_features [B, vision_soft_tokens_per_image, text_hidden_size]
|
| 170 |
+
"""
|
| 171 |
+
|
| 172 |
+
def __init__(
|
| 173 |
+
self,
|
| 174 |
+
vision_tower: nn.Module,
|
| 175 |
+
embed_vision: Optional[nn.Module],
|
| 176 |
+
vision_hidden_size: int,
|
| 177 |
+
vision_soft_tokens_per_image: int,
|
| 178 |
+
text_hidden_size: int,
|
| 179 |
+
num_output_tokens_reduced: Optional[int] = None,
|
| 180 |
+
num_heads_for_token_reduction: int = 4,
|
| 181 |
+
perform_norm_latent_for_token_reduction: bool = True,
|
| 182 |
+
reducer_bottleneck_dim: Optional[int] = None,
|
| 183 |
+
reducer_project_back: bool = True,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
self.vision_tower = vision_tower
|
| 187 |
+
self.embed_vision = embed_vision
|
| 188 |
+
self.vision_hidden_size = int(vision_hidden_size)
|
| 189 |
+
self.vision_soft_tokens_per_image = int(vision_soft_tokens_per_image)
|
| 190 |
+
self.text_hidden_size = int(text_hidden_size)
|
| 191 |
+
self.has_embed_vision = embed_vision is not None
|
| 192 |
+
|
| 193 |
+
# Freeze vision modules
|
| 194 |
+
self.vision_tower.requires_grad_(False)
|
| 195 |
+
if self.embed_vision is not None:
|
| 196 |
+
self.embed_vision.requires_grad_(False)
|
| 197 |
+
|
| 198 |
+
# Reduce number of tokens
|
| 199 |
+
if num_output_tokens_reduced is not None:
|
| 200 |
+
reducer_dim = text_hidden_size if self.has_embed_vision else vision_hidden_size
|
| 201 |
+
self.reducer = VisionTokenReducer(
|
| 202 |
+
vision_dim=reducer_dim,
|
| 203 |
+
num_output_tokens=num_output_tokens_reduced,
|
| 204 |
+
num_heads=num_heads_for_token_reduction,
|
| 205 |
+
perform_norm_latent=perform_norm_latent_for_token_reduction,
|
| 206 |
+
bottleneck_dim=reducer_bottleneck_dim,
|
| 207 |
+
project_back=reducer_project_back,
|
| 208 |
+
)
|
| 209 |
+
else:
|
| 210 |
+
self.reducer = None
|
| 211 |
+
|
| 212 |
+
def init_weights(self, std: float = 0.02):
|
| 213 |
+
if self.reducer is not None:
|
| 214 |
+
self.reducer.init_weights(std)
|
| 215 |
+
|
| 216 |
+
def get_actual_hidden_dim(self) -> int:
|
| 217 |
+
"""
|
| 218 |
+
Return the actual feature hidden dimension produced by this extractor.
|
| 219 |
+
|
| 220 |
+
The output dimension depends on:
|
| 221 |
+
- whether embed_vision is used
|
| 222 |
+
- whether a reducer is present
|
| 223 |
+
- reducer bottleneck + project_back configuration
|
| 224 |
+
|
| 225 |
+
Returns:
|
| 226 |
+
int: feature hidden size of output tokens
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
# Base dimension before reducer
|
| 230 |
+
base_dim = self.text_hidden_size if self.has_embed_vision else self.vision_hidden_size
|
| 231 |
+
|
| 232 |
+
# No reducer
|
| 233 |
+
if self.reducer is None:
|
| 234 |
+
return base_dim
|
| 235 |
+
|
| 236 |
+
# Reducer without bottleneck
|
| 237 |
+
if self.reducer.bottleneck_dim is None:
|
| 238 |
+
return base_dim
|
| 239 |
+
|
| 240 |
+
# Bottleneck reducer
|
| 241 |
+
if self.reducer.project_back:
|
| 242 |
+
return base_dim
|
| 243 |
+
|
| 244 |
+
# Bottleneck without projection back
|
| 245 |
+
return int(self.reducer.bottleneck_dim)
|
| 246 |
+
|
| 247 |
+
def train(self, mode: bool = True) -> "Gemma3nVisionFeatureExtractor":
|
| 248 |
+
""" Override train(): vision is not trainable"""
|
| 249 |
+
super().train(mode=mode)
|
| 250 |
+
self.vision_tower.eval()
|
| 251 |
+
if self.embed_vision is not None:
|
| 252 |
+
self.embed_vision.eval()
|
| 253 |
+
return self
|
| 254 |
+
|
| 255 |
+
def forward(
|
| 256 |
+
self,
|
| 257 |
+
pixel_values: torch.Tensor,
|
| 258 |
+
valid_positions: Optional[torch.Tensor] = None,
|
| 259 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 260 |
+
"""
|
| 261 |
+
Args:
|
| 262 |
+
pixel_values: [B, 3, H, W]
|
| 263 |
+
valid_positions:
|
| 264 |
+
Indicates which samples have valid images.
|
| 265 |
+
Supported formats:
|
| 266 |
+
- BoolTensor [B] where True means "has image"
|
| 267 |
+
- LongTensor [K] with indices of samples that have images
|
| 268 |
+
If None: assume all samples have images.
|
| 269 |
+
|
| 270 |
+
Returns:
|
| 271 |
+
features: [B, T_img, D]
|
| 272 |
+
vision_mask: [B, T_img] (1=valid vision token, 0=masked out)
|
| 273 |
+
"""
|
| 274 |
+
if pixel_values.dim() != 4:
|
| 275 |
+
raise ValueError(f"pixel_values must be [B,3,H,W], got {tuple(pixel_values.shape)}")
|
| 276 |
+
|
| 277 |
+
B = pixel_values.size(0)
|
| 278 |
+
device = next(self.vision_tower.parameters()).device
|
| 279 |
+
dtype = next(self.vision_tower.parameters()).dtype
|
| 280 |
+
|
| 281 |
+
# --------------------------------------------------------
|
| 282 |
+
# Build per-sample valid-image mask
|
| 283 |
+
# --------------------------------------------------------
|
| 284 |
+
if valid_positions is None:
|
| 285 |
+
valid_mask = torch.ones(B, dtype=torch.bool, device=pixel_values.device)
|
| 286 |
+
else:
|
| 287 |
+
if valid_positions.dtype == torch.bool:
|
| 288 |
+
if valid_positions.shape != (B,):
|
| 289 |
+
raise ValueError(f"valid_positions (bool) must be [B], got {tuple(valid_positions.shape)}")
|
| 290 |
+
valid_mask = valid_positions.to(device=pixel_values.device)
|
| 291 |
+
else:
|
| 292 |
+
if valid_positions.dim() != 1:
|
| 293 |
+
raise ValueError(f"valid_positions (indices) must be 1D, got {tuple(valid_positions.shape)}")
|
| 294 |
+
valid_mask = torch.zeros(B, dtype=torch.bool, device=pixel_values.device)
|
| 295 |
+
valid_mask[valid_positions.to(device=pixel_values.device, dtype=torch.long)] = True
|
| 296 |
+
|
| 297 |
+
num_valid = int(valid_mask.sum().item())
|
| 298 |
+
|
| 299 |
+
# --------------------------------------------------------
|
| 300 |
+
# Figure out final output shape in advance
|
| 301 |
+
# --------------------------------------------------------
|
| 302 |
+
if self.reducer is None:
|
| 303 |
+
T_img = self.vision_soft_tokens_per_image
|
| 304 |
+
else:
|
| 305 |
+
T_img = self.reducer.num_output_tokens
|
| 306 |
+
|
| 307 |
+
D_out = self.get_actual_hidden_dim()
|
| 308 |
+
|
| 309 |
+
# vision_mask always returned for full batch
|
| 310 |
+
vision_mask = valid_mask[:, None].expand(B, T_img).to(dtype=torch.long)
|
| 311 |
+
|
| 312 |
+
# Fast path: no valid image at all
|
| 313 |
+
if num_valid == 0:
|
| 314 |
+
features = torch.zeros(B, T_img, D_out, device=device, dtype=dtype)
|
| 315 |
+
return features, vision_mask
|
| 316 |
+
|
| 317 |
+
# --------------------------------------------------------
|
| 318 |
+
# Run only valid samples through frozen vision stack
|
| 319 |
+
# --------------------------------------------------------
|
| 320 |
+
pixel_values_valid = pixel_values[valid_mask].to(device=device, dtype=dtype)
|
| 321 |
+
|
| 322 |
+
with torch.no_grad():
|
| 323 |
+
vision_last = self.vision_tower(
|
| 324 |
+
pixel_values=pixel_values_valid,
|
| 325 |
+
do_pooling=False,
|
| 326 |
+
return_dict=True,
|
| 327 |
+
).last_hidden_state
|
| 328 |
+
|
| 329 |
+
if vision_last.dim() != 4:
|
| 330 |
+
raise RuntimeError(f"Expected vision last_hidden_state (B,C,h,w), got {tuple(vision_last.shape)}")
|
| 331 |
+
|
| 332 |
+
Bv, C, h, w = vision_last.shape
|
| 333 |
+
if Bv != num_valid:
|
| 334 |
+
raise RuntimeError("Batch size mismatch between valid pixel_values and vision_last")
|
| 335 |
+
if C != self.vision_hidden_size:
|
| 336 |
+
raise RuntimeError(f"Expected vision_hidden_size={self.vision_hidden_size}, got C={C}")
|
| 337 |
+
if h * w != self.vision_soft_tokens_per_image:
|
| 338 |
+
raise RuntimeError(
|
| 339 |
+
f"Expected h*w={self.vision_soft_tokens_per_image}, got {h * w}. "
|
| 340 |
+
f"Check processor image size/crop or config."
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
# (Bv, C, h, w) -> (Bv, C, HW) -> (Bv, HW, C)
|
| 344 |
+
vision_tokens = vision_last.reshape(Bv, C, self.vision_soft_tokens_per_image).permute(0, 2, 1).contiguous()
|
| 345 |
+
|
| 346 |
+
# Scale by sqrt(C) (matches Gemma codepath)
|
| 347 |
+
vision_tokens = vision_tokens * (self.vision_hidden_size ** 0.5)
|
| 348 |
+
|
| 349 |
+
# --------------------------------------------------------
|
| 350 |
+
# Extract valid-image features only
|
| 351 |
+
# --------------------------------------------------------
|
| 352 |
+
if not self.has_embed_vision:
|
| 353 |
+
valid_features = vision_tokens # [Bv, HW, C]
|
| 354 |
+
if self.reducer is not None:
|
| 355 |
+
valid_features = self.reducer(valid_features) # [Bv, T_img, C or d]
|
| 356 |
+
else:
|
| 357 |
+
with torch.no_grad():
|
| 358 |
+
valid_features = self.embed_vision(inputs_embeds=vision_tokens)
|
| 359 |
+
|
| 360 |
+
if valid_features.shape != (Bv, self.vision_soft_tokens_per_image, self.text_hidden_size):
|
| 361 |
+
raise RuntimeError(
|
| 362 |
+
f"Bad output shape {tuple(valid_features.shape)}; expected "
|
| 363 |
+
f"({Bv}, {self.vision_soft_tokens_per_image}, {self.text_hidden_size})"
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
if self.reducer is not None:
|
| 367 |
+
valid_features = self.reducer(valid_features)
|
| 368 |
+
|
| 369 |
+
# --------------------------------------------------------
|
| 370 |
+
# Scatter back to full batch; invalid samples stay zero
|
| 371 |
+
# --------------------------------------------------------
|
| 372 |
+
if valid_features.size(1) != T_img:
|
| 373 |
+
raise RuntimeError(f"T_img mismatch: expected {T_img}, got {valid_features.size(1)}")
|
| 374 |
+
if valid_features.size(2) != D_out:
|
| 375 |
+
raise RuntimeError(f"D_out mismatch: expected {D_out}, got {valid_features.size(2)}")
|
| 376 |
+
|
| 377 |
+
features = torch.zeros(B, T_img, D_out, device=valid_features.device, dtype=valid_features.dtype)
|
| 378 |
+
features[valid_mask] = valid_features
|
| 379 |
+
|
| 380 |
+
return features, vision_mask
|
| 381 |
+
|
| 382 |
+
@classmethod
|
| 383 |
+
def from_pretrained_vision_only_dir(
|
| 384 |
+
cls,
|
| 385 |
+
model_dir: str,
|
| 386 |
+
map_location: str = "cpu",
|
| 387 |
+
num_output_tokens_reduced: Optional[int] = None,
|
| 388 |
+
num_heads_for_token_reduction: int = 4,
|
| 389 |
+
perform_norm_latent_for_token_reduction: bool = True,
|
| 390 |
+
reducer_bottleneck_dim: Optional[int] = None,
|
| 391 |
+
reducer_project_back: bool = True,
|
| 392 |
+
) -> "Gemma3nVisionFeatureExtractor":
|
| 393 |
+
weights_path = os.path.join(model_dir, "model.safetensors")
|
| 394 |
+
if not os.path.isfile(weights_path):
|
| 395 |
+
raise FileNotFoundError(f"Missing weights: {weights_path}")
|
| 396 |
+
|
| 397 |
+
ve_cfg_path = os.path.join(model_dir, "vision_extractor_config.json")
|
| 398 |
+
if not os.path.isfile(ve_cfg_path):
|
| 399 |
+
raise FileNotFoundError(f"Missing {ve_cfg_path}")
|
| 400 |
+
ve_cfg = load_json(ve_cfg_path)
|
| 401 |
+
|
| 402 |
+
vision_soft_tokens_per_image = int(ve_cfg.get("vision_soft_tokens_per_image", 256))
|
| 403 |
+
vision_hidden_size = int(ve_cfg.get("vision_hidden_size", -1))
|
| 404 |
+
text_hidden_size = int(ve_cfg.get("text_hidden_size", -1))
|
| 405 |
+
has_embed_vision = bool(ve_cfg.get("has_embed_vision", True))
|
| 406 |
+
|
| 407 |
+
if vision_hidden_size <= 0:
|
| 408 |
+
raise ValueError("vision_hidden_size missing/invalid in vision_extractor_config.json")
|
| 409 |
+
if has_embed_vision and text_hidden_size <= 0:
|
| 410 |
+
raise ValueError("text_hidden_size missing/invalid in vision_extractor_config.json")
|
| 411 |
+
|
| 412 |
+
cfg = AutoConfig.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
|
| 413 |
+
vision_cfg = getattr(cfg, "vision_config", cfg)
|
| 414 |
+
text_cfg = getattr(cfg, "text_config", None)
|
| 415 |
+
|
| 416 |
+
vision_tower = AutoModel.from_config(vision_cfg, trust_remote_code=True)
|
| 417 |
+
|
| 418 |
+
embed_vision = None
|
| 419 |
+
if has_embed_vision:
|
| 420 |
+
if text_cfg is None:
|
| 421 |
+
raise RuntimeError(
|
| 422 |
+
"config.json does not contain text_config, but has_embed_vision=True. "
|
| 423 |
+
"You need a Gemma3nConfig-like config.json in this folder."
|
| 424 |
+
)
|
| 425 |
+
embed_vision = Gemma3nMultimodalEmbedder(vision_cfg, text_cfg)
|
| 426 |
+
|
| 427 |
+
sd = safetensors_load_file(weights_path, device=map_location)
|
| 428 |
+
|
| 429 |
+
vt_sd, ev_sd = _split_state_dict_from_tmp(sd)
|
| 430 |
+
if not vt_sd:
|
| 431 |
+
raise RuntimeError("No vision_tower.* keys found in model.safetensors")
|
| 432 |
+
if has_embed_vision and not ev_sd:
|
| 433 |
+
raise RuntimeError("has_embed_vision=True but no embed_vision.* keys found in model.safetensors")
|
| 434 |
+
|
| 435 |
+
missing_vt, unexpected_vt = vision_tower.load_state_dict(vt_sd, strict=True)
|
| 436 |
+
if missing_vt or unexpected_vt:
|
| 437 |
+
raise RuntimeError(f"vision_tower load mismatch: missing={missing_vt}, unexpected={unexpected_vt}")
|
| 438 |
+
|
| 439 |
+
if has_embed_vision:
|
| 440 |
+
missing_ev, unexpected_ev = embed_vision.load_state_dict(ev_sd, strict=True)
|
| 441 |
+
if missing_ev or unexpected_ev:
|
| 442 |
+
raise RuntimeError(f"embed_vision load mismatch: missing={missing_ev}, unexpected={unexpected_ev}")
|
| 443 |
+
|
| 444 |
+
vision_tower.eval()
|
| 445 |
+
if embed_vision is not None:
|
| 446 |
+
embed_vision.eval()
|
| 447 |
+
|
| 448 |
+
model = cls(
|
| 449 |
+
vision_tower=vision_tower,
|
| 450 |
+
embed_vision=embed_vision,
|
| 451 |
+
vision_hidden_size=vision_hidden_size,
|
| 452 |
+
vision_soft_tokens_per_image=vision_soft_tokens_per_image,
|
| 453 |
+
text_hidden_size=text_hidden_size if has_embed_vision else vision_hidden_size,
|
| 454 |
+
num_output_tokens_reduced=num_output_tokens_reduced,
|
| 455 |
+
num_heads_for_token_reduction=num_heads_for_token_reduction,
|
| 456 |
+
perform_norm_latent_for_token_reduction=perform_norm_latent_for_token_reduction,
|
| 457 |
+
reducer_bottleneck_dim=reducer_bottleneck_dim,
|
| 458 |
+
reducer_project_back=reducer_project_back,
|
| 459 |
+
)
|
| 460 |
+
model.eval()
|
| 461 |
+
return model
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def _demo_main():
|
| 465 |
+
import argparse
|
| 466 |
+
from PIL import Image
|
| 467 |
+
from transformers import AutoProcessor
|
| 468 |
+
from pathlib import Path
|
| 469 |
+
|
| 470 |
+
parser = argparse.ArgumentParser()
|
| 471 |
+
parser.add_argument("--model_dir", type=str, default="./model_weights/gemma3n_E2B_vision_only")
|
| 472 |
+
parser.add_argument("--device", type=str, default=None)
|
| 473 |
+
parser.add_argument("--dtype", type=str, default="float32", choices=["bfloat16", "float16", "float32"])
|
| 474 |
+
parser.add_argument("--num_output_tokens_reduced", type=int, default=32)
|
| 475 |
+
parser.add_argument("--reducer_bottleneck_dim", type=int, default=768)
|
| 476 |
+
parser.add_argument("--reducer_project_back", action="store_true")
|
| 477 |
+
args = parser.parse_args()
|
| 478 |
+
|
| 479 |
+
model_dir = str(Path(args.model_dir).resolve())
|
| 480 |
+
|
| 481 |
+
# Force local loading
|
| 482 |
+
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True, local_files_only=True)
|
| 483 |
+
|
| 484 |
+
model = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
|
| 485 |
+
model_dir=model_dir,
|
| 486 |
+
map_location="cpu",
|
| 487 |
+
num_output_tokens_reduced=args.num_output_tokens_reduced,
|
| 488 |
+
num_heads_for_token_reduction=4,
|
| 489 |
+
reducer_bottleneck_dim=args.reducer_bottleneck_dim,
|
| 490 |
+
reducer_project_back=args.reducer_project_back,
|
| 491 |
+
)
|
| 492 |
+
model.init_weights()
|
| 493 |
+
model.to(device=args.device, dtype=getattr(torch, args.dtype))  # args.dtype is a string; resolve it to a torch.dtype
|
| 494 |
+
model.eval()
|
| 495 |
+
|
| 496 |
+
def count_params(module):
|
| 497 |
+
return sum(p.numel() for p in module.parameters())
|
| 498 |
+
|
| 499 |
+
vision_params = count_params(model.vision_tower)
|
| 500 |
+
|
| 501 |
+
embed_params = 0
|
| 502 |
+
if model.has_embed_vision and model.embed_vision is not None:
|
| 503 |
+
embed_params = count_params(model.embed_vision)
|
| 504 |
+
|
| 505 |
+
reducer_params = 0
|
| 506 |
+
if model.reducer is not None:
|
| 507 |
+
reducer_params = count_params(model.reducer)
|
| 508 |
+
|
| 509 |
+
frozen_params = vision_params + embed_params
|
| 510 |
+
total_params = frozen_params + reducer_params
|
| 511 |
+
|
| 512 |
+
print(f"Vision tower parameters (frozen): {vision_params:,}")
|
| 513 |
+
|
| 514 |
+
if model.has_embed_vision:
|
| 515 |
+
print(f"Embed vision parameters (frozen): {embed_params:,}")
|
| 516 |
+
else:
|
| 517 |
+
print("Embed vision: NONE")
|
| 518 |
+
|
| 519 |
+
if model.reducer is not None:
|
| 520 |
+
print(f"Reducer parameters (trainable): {reducer_params:,}")
|
| 521 |
+
else:
|
| 522 |
+
print("Reducer: NONE")
|
| 523 |
+
|
| 524 |
+
print(f"Total frozen parameters: {frozen_params:,}")
|
| 525 |
+
print(f"Total trainable parameters: {reducer_params:,}")
|
| 526 |
+
print(f"Total parameters: {total_params:,}")
|
| 527 |
+
|
| 528 |
+
img1 = Image.new("RGB", (768, 768), color=(0, 0, 0))
|
| 529 |
+
img2 = Image.new("RGB", (768, 768), color=(255, 255, 255))
|
| 530 |
+
|
| 531 |
+
inputs = processor(
|
| 532 |
+
text=["", ""],
|
| 533 |
+
images=[[img1], [img2]],
|
| 534 |
+
return_tensors="pt",
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
pixel_values = inputs["pixel_values"].to(
|
| 538 |
+
device=next(model.parameters()).device,
|
| 539 |
+
dtype=next(model.parameters()).dtype,
|
| 540 |
+
)
|
| 541 |
+
|
| 542 |
+
print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device)
|
| 543 |
+
|
| 544 |
+
with torch.no_grad():
|
| 545 |
+
feats, masks = model(pixel_values)
|
| 546 |
+
|
| 547 |
+
print("features:", tuple(feats.shape), feats.dtype, feats.device)
|
| 548 |
+
print("masks:", tuple(masks.shape), masks.dtype, masks.device)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
if __name__ == "__main__":
|
| 552 |
+
_demo_main()
|
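A minimal standalone sketch of the per-sample image mask, assuming the vision-only checkpoint directory shipped in model_weights/ is available locally and the processor's 768x768 input size; everything else mirrors the forward contract documented above:

```python
import torch

model = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
    model_dir="./model_weights/gemma3n_E2B_vision_only",
    num_output_tokens_reduced=32,
)
model.init_weights()
model.eval()

B = 3
pixel_values = torch.zeros(B, 3, 768, 768)     # sample 1 has no photo, so its row is dummy
has_image = torch.tensor([True, False, True])  # bool [B] form of valid_positions

with torch.no_grad():
    feats, mask = model(pixel_values, valid_positions=has_image)

print(feats.shape)  # [3, 32, D]; row 1 is all zeros
print(mask[:, 0])   # tensor([1, 0, 1])
```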
modelling/layer.py
ADDED
|
@@ -0,0 +1,353 @@
|
| 1 |
+
# layer.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import math
|
| 5 |
+
from typing import Optional, Tuple
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F # noqa
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class RMSNorm(nn.Module):
|
| 13 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.eps = float(eps)
|
| 16 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 17 |
+
|
| 18 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 19 |
+
# x: [..., dim]
|
| 20 |
+
x_float = x.float()
|
| 21 |
+
rms = x_float.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
|
| 22 |
+
y = (x_float / rms).to(dtype=x.dtype)
|
| 23 |
+
return y * self.weight.to(dtype=x.dtype, device=x.device)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SwiGLU(nn.Module):
|
| 27 |
+
@staticmethod
|
| 28 |
+
def forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
|
| 29 |
+
return nn.functional.silu(gate) * up
|
| 30 |
+
|
| 31 |
+
|
class TabularImageGQALayer(nn.Module):
    """
    Pre-norm Transformer block with:
    - Tabular tokens produce Q; tabular+image produce KV (image optional)
    - GQA: num_query_heads is a multiple of num_kv_heads
    - Numeric+categorical must be concatenated before calling this layer (one tabular stream)
    - attention_mask is 1D [B, T_tab] and does not include vision tokens
    - If vision_features is None, attention is tabular-only
    - Vision tokens are not updated (no Q for vision)
    """

    def __init__(
        self,
        tabular_dim: int,
        vision_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        rmsnorm_eps: float = 1e-6,
    ):
        super().__init__()

        if num_query_heads % num_kv_heads != 0:
            raise ValueError("num_query_heads must be a multiple of num_kv_heads")

        self.tabular_dim = int(tabular_dim)
        self.vision_dim = int(vision_dim)
        self.num_query_heads = int(num_query_heads)
        self.num_kv_heads = int(num_kv_heads)
        self.head_dim = int(head_dim)

        self.q_dim = self.num_query_heads * self.head_dim
        self.kv_dim = self.num_kv_heads * self.head_dim
        self.group_size = self.num_query_heads // self.num_kv_heads

        self.attn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)

        # Tabular projections (shared for numeric+categorical stream)
        self.q_proj_tab = nn.Linear(self.tabular_dim, self.q_dim, bias=False)
        self.k_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)
        self.v_proj_tab = nn.Linear(self.tabular_dim, self.kv_dim, bias=False)

        # Vision KV projections (separate; vision has no Q)
        self.k_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)
        self.v_proj_img = nn.Linear(self.vision_dim, self.kv_dim, bias=False)

        self.o_proj = nn.Linear(self.q_dim, self.tabular_dim, bias=False)

        self.attn_dropout = float(dropout)
        self.resid_dropout = float(dropout)

        # FFN (LLM-style: gated MLP with SwiGLU)
        self.ffn_norm = RMSNorm(self.tabular_dim, eps=rmsnorm_eps)
        ffn_dim = int(round(self.tabular_dim * float(mlp_ratio)))

        self.gate_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
        self.up_proj = nn.Linear(self.tabular_dim, ffn_dim, bias=False)
        self.down_proj = nn.Linear(ffn_dim, self.tabular_dim, bias=False)
        self.act = SwiGLU()

    def init_weights(self, std: float = 0.02):
        # RMSNorm
        nn.init.ones_(self.attn_norm.weight)
        nn.init.ones_(self.ffn_norm.weight)

        # Attention projections
        nn.init.normal_(self.q_proj_tab.weight, std=std)
        nn.init.normal_(self.k_proj_tab.weight, std=std)
        nn.init.normal_(self.v_proj_tab.weight, std=std)
        nn.init.normal_(self.k_proj_img.weight, std=std)
        nn.init.normal_(self.v_proj_img.weight, std=std)
        nn.init.normal_(self.o_proj.weight, std=std)

        # FFN
        nn.init.normal_(self.gate_proj.weight, std=std)
        nn.init.normal_(self.up_proj.weight, std=std)
        nn.init.normal_(self.down_proj.weight, std=std)

    @staticmethod
    def _make_key_bias_from_mask(mask_1d: torch.Tensor, key_len: int) -> torch.Tensor:
        """
        mask_1d: [B, T_key] with 1=keep, 0=mask
        returns: [B, 1, 1, T_key] float bias with 0 for keep and a large negative value (-1e9) for mask
        """
        if mask_1d.dtype != torch.float32:
            mask_f = mask_1d.float()
        else:
            mask_f = mask_1d
        if mask_f.shape[1] != key_len:
            raise ValueError(f"mask_1d width mismatch: got {mask_f.shape[1]} expected {key_len}")
        bias = (1.0 - mask_f) * -1e9
        return bias.view(mask_f.shape[0], 1, 1, key_len)
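
    # Illustration (added comment, not original): for mask_1d = [[1, 1, 0]] the
    # bias is [[[[0, 0, -1e9]]]]; adding it to attention scores before softmax
    # drives masked-key probabilities to ~0 without the NaN risk of a true -inf.
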
    def _split_heads_q(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, Hq*d] -> [B, Hq, T, d]
        B, T, _ = x.shape
        return x.view(B, T, self.num_query_heads, self.head_dim).transpose(1, 2).contiguous()

    def _split_heads_kv(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, Hkv*d] -> [B, Hkv, T, d]
        B, T, _ = x.shape
        return x.view(B, T, self.num_kv_heads, self.head_dim).transpose(1, 2).contiguous()

    @staticmethod
    def _merge_heads_q(x: torch.Tensor) -> torch.Tensor:
        # x: [B, Hq, T, d] -> [B, T, Hq*d]
        B, H, T, d = x.shape
        return x.transpose(1, 2).contiguous().view(B, T, H * d)

    def forward(
        self,
        x_tab: torch.Tensor,
        attention_mask: torch.Tensor,
        vision_features: Optional[torch.Tensor] = None,
        vision_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        x_tab: [B, T_tab, tabular_dim]
        attention_mask: [B, T_tab] (1=valid tab token, 0=masked tab token). Does NOT include vision.
        vision_features: None or [B, T_img, vision_dim]
        vision_mask: None or [B, T_img] (1=valid vision token, 0=masked). Required if vision_features is not None.
        returns: updated x_tab [B, T_tab, tabular_dim]
        """
        if x_tab.dim() != 3:
            raise ValueError(f"x_tab must be [B,T,D], got {tuple(x_tab.shape)}")
        if attention_mask.dim() != 2:
            raise ValueError(f"attention_mask must be [B,T_tab], got {tuple(attention_mask.shape)}")

        B, T_tab, D = x_tab.shape
        if D != self.tabular_dim:
            raise ValueError(f"tabular_dim mismatch: got {D}, expected {self.tabular_dim}")
        if attention_mask.shape != (B, T_tab):
            raise ValueError("attention_mask shape mismatch with x_tab")
        if attention_mask.device != x_tab.device:
            attention_mask = attention_mask.to(device=x_tab.device)

        # ---- Attention block (pre-norm)
        h = self.attn_norm(x_tab)

        q_tab = self.q_proj_tab(h)  # [B, T_tab, Hq*d]
        k_tab = self.k_proj_tab(h)  # [B, T_tab, Hkv*d]
        v_tab = self.v_proj_tab(h)  # [B, T_tab, Hkv*d]

        q = self._split_heads_q(q_tab)  # [B, Hq, T_tab, d]
        k_tab = self._split_heads_kv(k_tab)  # [B, Hkv, T_tab, d]
        v_tab = self._split_heads_kv(v_tab)  # [B, Hkv, T_tab, d]

        if vision_features is None:
            # Keys/values = tab only
            k = k_tab
            v = v_tab
            key_mask = attention_mask  # [B, T_tab]
        else:
            if vision_features.dim() != 3:
                raise ValueError(f"vision_features must be [B,T_img,Dv], got {tuple(vision_features.shape)}")
            if vision_features.shape[0] != B:
                raise ValueError("vision_features batch mismatch")
            if vision_features.shape[2] != self.vision_dim:
                raise ValueError(f"vision_dim mismatch: got {vision_features.shape[2]}, expected {self.vision_dim}")

            # Require vision_mask for strict missing handling
            if vision_mask is None:
                raise ValueError("vision_mask must be provided when vision_features is not None")
            if vision_mask.dim() != 2:
                raise ValueError(f"vision_mask must be [B,T_img], got {tuple(vision_mask.shape)}")

            T_img = vision_features.shape[1]
            if vision_mask.shape != (B, T_img):
                raise ValueError(f"vision_mask shape mismatch: expected {(B, T_img)}, got {tuple(vision_mask.shape)}")

            # Ensure mask dtype matches attention_mask dtype for concatenation
            if vision_mask.dtype != attention_mask.dtype:
                vision_mask = vision_mask.to(dtype=attention_mask.dtype)
            if vision_mask.device != attention_mask.device:
                vision_mask = vision_mask.to(device=attention_mask.device)

            param = self.k_proj_img.weight
            vision_features = vision_features.to(device=param.device, dtype=param.dtype)
            k_img = self.k_proj_img(vision_features)  # [B, T_img, Hkv*d]
            v_img = self.v_proj_img(vision_features)  # [B, T_img, Hkv*d]
            k_img = self._split_heads_kv(k_img)  # [B, Hkv, T_img, d]
            v_img = self._split_heads_kv(v_img)  # [B, Hkv, T_img, d]

            k = torch.cat([k_tab, k_img], dim=2)  # [B, Hkv, T_tab+T_img, d]
            v = torch.cat([v_tab, v_img], dim=2)  # [B, Hkv, T_tab+T_img, d]

            # STRICT key mask: tab_mask + vision_mask
            key_mask = torch.cat([attention_mask, vision_mask], dim=1)  # [B, T_tab+T_img]

        # Expand KV heads to Q heads (GQA)
        if self.group_size != 1:
            k = k.repeat_interleave(self.group_size, dim=1)  # [B, Hq, T_k, d]
            v = v.repeat_interleave(self.group_size, dim=1)  # [B, Hq, T_k, d]

        T_k = k.shape[2]
        key_bias = self._make_key_bias_from_mask(key_mask, key_len=T_k)  # [B,1,1,T_k]

        # Attention scores: [B, Hq, T_tab, T_k]
        scale = 1.0 / math.sqrt(self.head_dim)
        attn_scores = torch.einsum("bhtd,bhkd->bhtk", q, k) * scale
        attn_scores = attn_scores + key_bias  # broadcast

        attn_probs = F.softmax(attn_scores.float(), dim=-1)
        if self.attn_dropout > 0.0 and self.training:
            attn_probs = F.dropout(attn_probs, p=self.attn_dropout)
        attn_probs = attn_probs.to(v.dtype)

        attn_out = torch.einsum("bhtk,bhkd->bhtd", attn_probs, v)  # [B,Hq,T_tab,d]
        attn_out = self._merge_heads_q(attn_out)  # [B,T_tab,Hq*d]
        attn_out = self.o_proj(attn_out)  # [B,T_tab,tab_dim]

        # Query-side masking (tab only): prevents masked tab tokens from updating residual path
        attn_out = attn_out * attention_mask.to(attn_out.dtype).unsqueeze(-1)

        if self.resid_dropout > 0.0 and self.training:
            attn_out = F.dropout(attn_out, p=self.resid_dropout)

        x = x_tab + attn_out

        # ---- FFN block (pre-norm)
        h2 = self.ffn_norm(x)
        gate = self.gate_proj(h2)
        up = self.up_proj(h2)
        f = self.act(gate, up)
        f = self.down_proj(f)

        # Query-side masking (tab only)
        f = f * attention_mask.to(f.dtype).unsqueeze(-1)

        if self.resid_dropout > 0.0 and self.training:
            f = F.dropout(f, p=self.resid_dropout)

        x = x + f
        return x
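
# Added sketch (illustrative only): how GQA shares KV heads across query heads.
# With Hq=8 and Hkv=2 each KV head serves group_size = 8 // 2 = 4 query heads;
# repeat_interleave(group_size, dim=1) materializes that sharing.
def _example_gqa_expand(num_query_heads: int = 8, num_kv_heads: int = 2):
    group_size = num_query_heads // num_kv_heads
    k = torch.randn(1, num_kv_heads, 5, 16)  # [B, Hkv, T_k, d]
    k_expanded = k.repeat_interleave(group_size, dim=1)  # [B, Hq, T_k, d]
    return k_expanded.shape  # torch.Size([1, 8, 5, 16])
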
def _count_params(m: nn.Module) -> Tuple[int, int]:
    total = sum(p.numel() for p in m.parameters())
    trainable = sum(p.numel() for p in m.parameters() if p.requires_grad)
    return total, trainable


def _demo_main():
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--t_tab", type=int, default=126)
    parser.add_argument("--t_img", type=int, default=256)
    parser.add_argument("--tabular_dim", type=int, default=768)
    parser.add_argument("--vision_dim", type=int, default=768)
    parser.add_argument("--num_query_heads", type=int, default=8)
    parser.add_argument("--num_kv_heads", type=int, default=2)
    parser.add_argument("--head_dim", type=int, default=128)
    parser.add_argument("--mlp_ratio", type=float, default=1.5)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument("--with_vision", action="store_true")
    parser.add_argument("--dtype", type=str, default="float32", choices=["float16", "bfloat16", "float32"])
    parser.add_argument("--device", type=str, default=None)
    args = parser.parse_args()

    device = torch.device(args.device or ("cuda" if torch.cuda.is_available() else "cpu"))
    dtype_map = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    layer = TabularImageGQALayer(
        tabular_dim=args.tabular_dim,
        vision_dim=args.vision_dim,
        num_query_heads=args.num_query_heads,
        num_kv_heads=args.num_kv_heads,
        head_dim=args.head_dim,
        mlp_ratio=args.mlp_ratio,
        dropout=args.dropout,
    ).to(device=device, dtype=dtype)

    total, trainable = _count_params(layer)
    print(f"Layer parameters: {total:,} (trainable: {trainable:,})")

    B = args.batch_size
    T_tab = args.t_tab

    x_tab = torch.randn(B, T_tab, args.tabular_dim, device=device, dtype=dtype)

    # Build a typical HF-style 1D attention mask: 1 for valid, 0 for masked/padded.
    # Here we create variable valid lengths.
    lengths = torch.randint(low=max(1, T_tab // 2), high=T_tab + 1, size=(B,), device=device)
    attention_mask = torch.zeros(B, T_tab, device=device, dtype=torch.long)
    for b in range(B):
        attention_mask[b, : int(lengths[b].item())] = 1

    if args.with_vision:
        vision = torch.randn(B, args.t_img, args.vision_dim, device=device, dtype=dtype)

        # Example vision mask: first half valid for sample 0, all valid for others
        vision_mask = torch.ones(B, args.t_img, device=device, dtype=torch.long)
        if args.t_img > 0:
            vision_mask[0, args.t_img // 2:] = 0
    else:
        vision = None
        vision_mask = None

    print("Input x_tab:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)
    print("Input attention_mask:", tuple(attention_mask.shape), attention_mask.dtype, attention_mask.device)
    print("Input vision_features:", None if vision is None else (tuple(vision.shape), vision.dtype, vision.device))
    print("Input vision_mask:",
          None if vision_mask is None else (tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device))

    with torch.no_grad():
        y = layer(
            x_tab=x_tab,
            attention_mask=attention_mask,
            vision_features=vision,
            vision_mask=vision_mask,
        )

    print("Output y_tab:", tuple(y.shape), y.dtype, y.device)


if __name__ == "__main__":
    _demo_main()
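Because layer.py has no package-relative imports, its `_demo_main` smoke test can be run directly from the repository root; for example (flags as defined in `_demo_main` above, values illustrative):

    python modelling/layer.py --with_vision --dtype bfloat16 --t_tab 126 --t_img 256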
modelling/loader.py
ADDED
@@ -0,0 +1,1025 @@
# loader.py
# -*- coding: utf-8 -*-

import ast
from io import BytesIO
from urllib.parse import urljoin

import pandas as pd
import requests
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from utils import load_json


class CenterSquareCrop:
    """
    Crop image to a centered square without resizing.
    """

    def __call__(self, img: Image.Image):
        w, h = img.size

        if w == h:
            return img

        if w > h:
            left = (w - h) // 2
            right = left + h
            top = 0
            bottom = h
        else:
            top = (h - w) // 2
            bottom = top + w
            left = 0
            right = w
        return img.crop((left, top, right, bottom))


def build_image_transform(image_size: int):
    return transforms.Compose([
        CenterSquareCrop(),
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
    ])
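
# Added sketch (illustrative only): the transform crops to a centered square,
# resizes, and converts to a float tensor in [0, 1].
def _example_image_transform():
    transform = build_image_transform(image_size=512)
    img = Image.new("RGB", (800, 600))  # dummy non-square image
    pixel_values = transform(img)
    return pixel_values.shape  # torch.Size([3, 512, 512])
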
def join_photo_root(photo_root: str, relative_path: str) -> str:
    """
    Join photo_root and relative path.

    Supports:
    - local filesystem roots
    - http / https roots
    """
    if photo_root.startswith("http://") or photo_root.startswith("https://"):  # noqa
        return urljoin(photo_root.rstrip("/") + "/", relative_path)

    return photo_root.rstrip("/") + "/" + relative_path.lstrip("/")
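
# Added examples (illustrative only):
#   join_photo_root("https://host/photos", "a/b.jpg") -> "https://host/photos/a/b.jpg"
#   join_photo_root("/data/photos/", "/a/b.jpg")      -> "/data/photos/a/b.jpg"
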
def parse_numeric_cell(value: str, n_in: int):
    """
    Convert numeric csv cell to list[float].

    Returns:
        values, is_valid

    Data assumption:
    - Empty value is always ""
    - Scalar numeric -> "12.3"
    - Vector numeric -> "[1.2,3.4,5.6]"
    """
    if value == "":
        return [0.0] * n_in, False

    if n_in == 1:
        return [float(value)], True

    vec = ast.literal_eval(value)
    if len(vec) != n_in:
        raise ValueError(f"Numeric vector length mismatch: expected {n_in}, got {len(vec)}")
    return [float(v) for v in vec], True
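
# Added examples (illustrative only), following the cell conventions above:
#   parse_numeric_cell("", 3)          -> ([0.0, 0.0, 0.0], False)
#   parse_numeric_cell("12.3", 1)      -> ([12.3], True)
#   parse_numeric_cell("[1.2,3.4]", 2) -> ([1.2, 3.4], True)
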
class SoilFormerDataset(Dataset):

    def __init__(
        self,
        csv_path: str,
        photo_map_path: str,
        cat_vocab_path: str,
        numeric_vocab_path: str,
        numeric_stats_path: str,
        photo_root: str,
        image_size: int = 512,
        id_column: str = "id",
    ):
        self.df = pd.read_csv(
            csv_path,
            keep_default_na=False,
            na_filter=False,
            low_memory=False,
        )

        self.photo_map = load_json(photo_map_path)
        self.cat_vocab = load_json(cat_vocab_path)
        self.numeric_vocab = load_json(numeric_vocab_path)

        self.photo_root = photo_root
        self.id_column = id_column
        self.image_size = int(image_size)
        self.image_transform = build_image_transform(self.image_size)

        # Keep json order exactly
        self.cat_columns = list(self.cat_vocab.keys())
        self.numeric_groups = self.numeric_vocab["groups"]
        self.numeric_stats_df = pd.read_csv(numeric_stats_path)
        self.numeric_stats_index = self.numeric_stats_df.set_index("column")

        # Numeric mean/std
        self.numeric_stats = {}
        for _, row in self.numeric_stats_df.iterrows():
            col = row["column"]
            mean = float(row["mean"])
            std = float(row["std"])
            if std == 0.0:
                std = 1.0
            self.numeric_stats[col] = (mean, std)

        # For active masking
        self.cat_mask_local_ids = torch.tensor(
            [int(self.cat_vocab[col]["mask_local_id"]) for col in self.cat_columns],
            dtype=torch.long,
        )

    def __len__(self):
        return len(self.df)

    def load_image(self, path: str):
        if path.startswith("http://") or path.startswith("https://"):  # noqa
            resp = requests.get(path, timeout=(3, 10))
            resp.raise_for_status()
            img = Image.open(BytesIO(resp.content)).convert("RGB")
        else:
            img = Image.open(path).convert("RGB")

        return self.image_transform(img)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        sample_id = row[self.id_column]

        # -----------------------
        # categorical features
        # -----------------------
        cat_ids = []
        cat_valids = []

        for col in self.cat_columns:
            spec = self.cat_vocab[col]
            label2id = spec["label2id"]
            mask_id = spec["mask_local_id"]

            value = row[col]

            if value == "":
                cat_ids.append(mask_id)
                cat_valids.append(False)
            else:
                if value not in label2id:
                    raise KeyError(f"Unknown categorical value: column={col}, value={value!r}")
                cat_ids.append(label2id[value])
                cat_valids.append(True)

        cat_ids = torch.tensor(cat_ids, dtype=torch.long)
        cat_valids = torch.tensor(cat_valids, dtype=torch.bool)

        # -----------------------
        # numeric features
        # -----------------------
        numeric_values_by_nin = {}
        numeric_valid_positions_by_nin = {}

        for group in self.numeric_groups:
            n_in = int(group["n_in"])
            features = group["feature_names"]

            values = []
            valids = []

            for feat in features:
                cell = row[feat]
                parsed, is_valid = parse_numeric_cell(cell, n_in)
                if is_valid:
                    mean, std = self.numeric_stats[feat]
                    parsed = [(v - mean) / std for v in parsed]
                values.append(parsed)
                valids.append(is_valid)

            numeric_values_by_nin[n_in] = torch.tensor(values, dtype=torch.float32)
            numeric_valid_positions_by_nin[n_in] = torch.tensor(valids, dtype=torch.bool)

        # -----------------------
        # vision
        # -----------------------
        try:
            relative_path = self.photo_map[sample_id]
            full_path = join_photo_root(self.photo_root, relative_path)
            image = self.load_image(full_path)
            vision_valid = True
        except Exception:  # noqa
            image = torch.zeros(3, self.image_size, self.image_size, dtype=torch.float32)
            vision_valid = False

        vision_valid = torch.tensor(vision_valid, dtype=torch.bool)

        return {
            "row_idx": torch.tensor(idx, dtype=torch.long),
            "sample_id": sample_id,
            "cat_local_ids": cat_ids,
            "cat_valid_positions": cat_valids,
            "numeric_values_by_nin": numeric_values_by_nin,
            "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
            "pixel_values": image,
            "vision_valid_positions": vision_valid,
        }

    @staticmethod
    def collate_fn(batch):
        cat_ids = torch.stack([b["cat_local_ids"] for b in batch], dim=0)
        cat_valids = torch.stack([b["cat_valid_positions"] for b in batch], dim=0)

        group_keys = list(batch[0]["numeric_values_by_nin"].keys())

        numeric_values_by_nin = {}
        numeric_valid_positions_by_nin = {}

        for k in group_keys:
            numeric_values_by_nin[k] = torch.stack(
                [b["numeric_values_by_nin"][k] for b in batch],
                dim=0,
            )
            numeric_valid_positions_by_nin[k] = torch.stack(
                [b["numeric_valid_positions_by_nin"][k] for b in batch],
                dim=0,
            )

        pixel_values = torch.stack([b["pixel_values"] for b in batch], dim=0)
        vision_valid_positions = torch.stack([b["vision_valid_positions"] for b in batch], dim=0)
        row_idx = torch.stack([b["row_idx"] for b in batch], dim=0)
        sample_ids = [b["sample_id"] for b in batch]

        return {
            "row_idx": row_idx,
            "sample_id": sample_ids,
            "cat_local_ids": cat_ids,
            "numeric_values_by_nin": numeric_values_by_nin,
            "cat_valid_positions": cat_valids,
            "numeric_valid_positions_by_nin": numeric_valid_positions_by_nin,
            "pixel_values": pixel_values,
            "vision_valid_positions": vision_valid_positions,
        }

    def perform_active_mask(self, batch, cat_ratio=0.15, num_ratio=0.15, seed=None):
        """
        Apply active masking to categorical and numeric inputs.

        Conventions
        -----------
        Input batch must contain:
        - cat_local_ids: [B, M] LongTensor
        - cat_valid_positions: [B, M] Bool/0-1 tensor
        - numeric_values_by_nin: Dict[int, Tensor[B, V, n_in]]
        - numeric_valid_positions_by_nin: Dict[int, Tensor[B, V]]

        Output batch will additionally contain:
        - original_cat_local_ids
        - original_cat_valid_positions
        - original_numeric_values_by_nin
        - original_numeric_valid_positions_by_nin

        - masked_cat_local_ids
        - masked_cat_valid_positions
        - masked_numeric_values_by_nin
        - masked_numeric_valid_positions_by_nin

        - cat_loss_mask: [B, M] BoolTensor
        - numeric_loss_mask_by_nin: Dict[int, BoolTensor[B, V]]

        Semantics
        ---------
        - Only originally valid positions can be actively masked.
        - Masked categorical positions:
              local_id -> self.cat_mask_local_ids[col]
              valid -> False
        - Masked numeric positions:
              values -> 0
              valid -> False
        - original_* fields always preserve the unmodified input batch content.
        """
        # --------------------------------------------------
        # Validate ratios
        # --------------------------------------------------
        if not (0.0 <= cat_ratio <= 1.0):
            raise ValueError(f"cat_ratio must be in [0, 1], got {cat_ratio}")
        if not (0.0 <= num_ratio <= 1.0):
            raise ValueError(f"num_ratio must be in [0, 1], got {num_ratio}")

        # --------------------------------------------------
        # Validate required keys
        # --------------------------------------------------
        required_keys = [
            "cat_local_ids",
            "cat_valid_positions",
            "numeric_values_by_nin",
            "numeric_valid_positions_by_nin",
        ]
        for k in required_keys:
            if k not in batch:
                raise KeyError(f"Missing key in batch: {k}")

        cat_local_ids = batch["cat_local_ids"]
        cat_valid_positions = batch["cat_valid_positions"]
        numeric_values_by_nin = batch["numeric_values_by_nin"]
        numeric_valid_positions_by_nin = batch["numeric_valid_positions_by_nin"]

        if cat_local_ids.dim() != 2:
            raise ValueError(f"cat_local_ids must be [B, M], got {tuple(cat_local_ids.shape)}")
        if cat_valid_positions.shape != cat_local_ids.shape:
            raise ValueError(
                f"cat_valid_positions must match cat_local_ids shape, got "
                f"{tuple(cat_valid_positions.shape)} vs {tuple(cat_local_ids.shape)}"
            )

        if not isinstance(numeric_values_by_nin, dict):
            raise ValueError("numeric_values_by_nin must be a dict")
        if not isinstance(numeric_valid_positions_by_nin, dict):
            raise ValueError("numeric_valid_positions_by_nin must be a dict")

        B, M = cat_local_ids.shape
        device = cat_local_ids.device

        if self.cat_mask_local_ids.dim() != 1 or self.cat_mask_local_ids.numel() != M:
            raise ValueError(
                f"self.cat_mask_local_ids must be [M] with M={M}, got {tuple(self.cat_mask_local_ids.shape)}"
            )
        cat_mask_local_ids = self.cat_mask_local_ids.to(device=device, dtype=cat_local_ids.dtype)

        # --------------------------------------------------
        # Random generator
        # --------------------------------------------------
        if device.type == "cuda":
            generator = torch.Generator(device=device)
        else:
            generator = torch.Generator()

        if seed is not None:
            generator.manual_seed(seed)

        # --------------------------------------------------
        # Start from shallow copy only
        # --------------------------------------------------
        masked_batch = dict(batch)

        # Preserve original aliases (do NOT deepcopy)
        masked_batch["original_cat_local_ids"] = batch["cat_local_ids"]
        masked_batch["original_cat_valid_positions"] = batch["cat_valid_positions"]
        masked_batch["original_numeric_values_by_nin"] = batch["numeric_values_by_nin"]
        masked_batch["original_numeric_valid_positions_by_nin"] = batch["numeric_valid_positions_by_nin"]

        # --------------------------------------------------
        # Fast path: no active masking at all
        # --------------------------------------------------
        if cat_ratio == 0.0 and num_ratio == 0.0:
            masked_batch["masked_cat_local_ids"] = batch["cat_local_ids"]
            masked_batch["masked_cat_valid_positions"] = batch["cat_valid_positions"]

            masked_batch["masked_numeric_values_by_nin"] = batch["numeric_values_by_nin"]
            masked_batch["masked_numeric_valid_positions_by_nin"] = batch["numeric_valid_positions_by_nin"]

            masked_batch["cat_loss_mask"] = torch.zeros(
                (B, M), dtype=torch.bool, device=device
            )
            masked_batch["numeric_loss_mask_by_nin"] = {
                n_in: torch.zeros_like(valid_positions, dtype=torch.bool)
                for n_in, valid_positions in numeric_valid_positions_by_nin.items()
            }
            return masked_batch

        # --------------------------------------------------
        # Categorical masking
        # --------------------------------------------------
        original_cat_valid_positions = cat_valid_positions.bool()

        masked_cat_local_ids = cat_local_ids.clone()
        masked_cat_valid_positions = original_cat_valid_positions.clone()
        cat_loss_mask = torch.zeros((B, M), dtype=torch.bool, device=device)

        if cat_ratio > 0.0:
            for b in range(B):
                valid_idx = torch.nonzero(original_cat_valid_positions[b], as_tuple=False).squeeze(1)
                n_valid = valid_idx.numel()
                if n_valid == 0:
                    continue

                k = int(round(n_valid * cat_ratio))
                if k <= 0:
                    continue
                if k > n_valid:
                    k = n_valid

                perm = valid_idx[
                    torch.randperm(n_valid, generator=generator, device=device)[:k]
                ]
                cat_loss_mask[b, perm] = True

            expanded_cat_mask_ids = cat_mask_local_ids.view(1, M).expand(B, M)
            masked_cat_local_ids[cat_loss_mask] = expanded_cat_mask_ids[cat_loss_mask]
            masked_cat_valid_positions = masked_cat_valid_positions & (~cat_loss_mask)

        masked_batch["masked_cat_local_ids"] = masked_cat_local_ids
        masked_batch["masked_cat_valid_positions"] = masked_cat_valid_positions
        masked_batch["cat_loss_mask"] = cat_loss_mask

        # --------------------------------------------------
        # Numeric masking
        # --------------------------------------------------
        masked_numeric_values_by_nin = {}
        masked_numeric_valid_positions_by_nin = {}
        numeric_loss_mask_by_nin = {}

        # keep deterministic ordering if caller passed mixed int-like keys
        for n_in in sorted(numeric_values_by_nin.keys(), key=int):
            values = numeric_values_by_nin[n_in]
            if n_in not in numeric_valid_positions_by_nin:
                raise KeyError(f"Missing numeric_valid_positions_by_nin[{n_in}]")

            valid_positions = numeric_valid_positions_by_nin[n_in]

            if values.dim() != 3:
                raise ValueError(
                    f"numeric_values_by_nin[{n_in}] must be [B, V, n_in], got {tuple(values.shape)}"
                )

            Bn, V, Nin = values.shape
            if Bn != B:
                raise ValueError(
                    f"numeric_values_by_nin[{n_in}] batch mismatch: got {Bn}, expected {B}"
                )
            if int(Nin) != int(n_in):
                raise ValueError(
                    f"numeric_values_by_nin[{n_in}] last dim mismatch: got {Nin}, expected {n_in}"
                )
            if valid_positions.shape != (B, V):
                raise ValueError(
                    f"numeric_valid_positions_by_nin[{n_in}] must be [B,V]=({B},{V}), "
                    f"got {tuple(valid_positions.shape)}"
                )

            original_valid = valid_positions.bool()

            # IMPORTANT: clone before modifying
            masked_values = values.clone()
            masked_valid_positions = original_valid.clone()
            num_loss_mask = torch.zeros((B, V), dtype=torch.bool, device=values.device)

            if num_ratio > 0.0:
                for b in range(B):
                    valid_idx = torch.nonzero(original_valid[b], as_tuple=False).squeeze(1)
                    n_valid = valid_idx.numel()
                    if n_valid == 0:
                        continue

                    k = int(round(n_valid * num_ratio))
                    if k <= 0:
                        continue
                    if k > n_valid:
                        k = n_valid

                    perm = valid_idx[
                        torch.randperm(n_valid, generator=generator, device=values.device)[:k]
                    ]
                    num_loss_mask[b, perm] = True

                # masked numeric columns become zero and invalid
                masked_values[num_loss_mask] = 0.0
                masked_valid_positions = masked_valid_positions & (~num_loss_mask)

            masked_numeric_values_by_nin[n_in] = masked_values
            masked_numeric_valid_positions_by_nin[n_in] = masked_valid_positions
            numeric_loss_mask_by_nin[n_in] = num_loss_mask

        masked_batch["masked_numeric_values_by_nin"] = masked_numeric_values_by_nin
        masked_batch["masked_numeric_valid_positions_by_nin"] = masked_numeric_valid_positions_by_nin
        masked_batch["numeric_loss_mask_by_nin"] = numeric_loss_mask_by_nin

        return masked_batch
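
    # Added usage sketch (illustrative only): a typical pretraining step masks
    # 15% of the originally valid categorical and numeric positions, then
    # computes losses only where cat_loss_mask / numeric_loss_mask_by_nin are True:
    #
    #     masked = dataset.perform_active_mask(batch, cat_ratio=0.15, num_ratio=0.15, seed=0)
    #     inputs = masked["masked_cat_local_ids"], masked["masked_numeric_values_by_nin"]
    #     targets_at = masked["cat_loss_mask"], masked["numeric_loss_mask_by_nin"]
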
    def perform_active_mask_single(self, batch, feature_name, assert_not_missing=True):
        """
        Actively mask exactly one feature specified by feature_name.

        Parameters
        ----------
        batch : dict
            Same input convention as perform_active_mask(...).
        feature_name : str
            Full feature name. Can be either categorical or numeric.
        assert_not_missing : bool
            If True, require the target feature to be originally valid for all samples
            in the batch. Otherwise raise ValueError.
            If False, only originally valid positions are masked; naturally missing
            positions remain missing and are not included in the loss mask.

        Returns
        -------
        masked_batch : dict
            Same output convention as perform_active_mask(...), except that exactly
            one feature is actively masked.
        """

        # --------------------------------------------------
        # Validate required keys
        # --------------------------------------------------
        required_keys = [
            "cat_local_ids",
            "cat_valid_positions",
            "numeric_values_by_nin",
            "numeric_valid_positions_by_nin",
        ]
        for k in required_keys:
            if k not in batch:
                raise KeyError(f"Missing key in batch: {k}")

        cat_local_ids = batch["cat_local_ids"]
        cat_valid_positions = batch["cat_valid_positions"]
        numeric_values_by_nin = batch["numeric_values_by_nin"]
        numeric_valid_positions_by_nin = batch["numeric_valid_positions_by_nin"]

        if cat_local_ids.dim() != 2:
            raise ValueError(f"cat_local_ids must be [B, M], got {tuple(cat_local_ids.shape)}")
        if cat_valid_positions.shape != cat_local_ids.shape:
            raise ValueError(
                f"cat_valid_positions must match cat_local_ids shape, got "
                f"{tuple(cat_valid_positions.shape)} vs {tuple(cat_local_ids.shape)}"
            )

        if not isinstance(numeric_values_by_nin, dict):
            raise ValueError("numeric_values_by_nin must be a dict")
        if not isinstance(numeric_valid_positions_by_nin, dict):
            raise ValueError("numeric_valid_positions_by_nin must be a dict")

        B, M = cat_local_ids.shape
        device = cat_local_ids.device

        if self.cat_mask_local_ids.dim() != 1 or self.cat_mask_local_ids.numel() != M:
            raise ValueError(
                f"self.cat_mask_local_ids must be [M] with M={M}, got {tuple(self.cat_mask_local_ids.shape)}"
            )
        cat_mask_local_ids = self.cat_mask_local_ids.to(device=device, dtype=cat_local_ids.dtype)

        # --------------------------------------------------
        # Resolve feature_name -> categorical col or numeric (n_in, v_idx)
        # --------------------------------------------------
        # Assumptions:
        # - self.cat_vocab is the categorical vocab dict keyed by full feature name
        # - self.numeric_vocab contains:
        #       numeric_vocab["ordered_feature_names"]
        #       numeric_vocab["features"][name]["n_in"]
        #       numeric_vocab["features"][name]["col_id"]
        #
        # If your actual attribute names differ, only this block needs adaptation.
        is_cat = False
        is_num = False
        cat_col = None
        num_n_in = None
        num_v_idx = None

        # categorical
        if hasattr(self, "cat_vocab") and feature_name in self.cat_vocab:
            is_cat = True
            cat_col = int(self.cat_vocab[feature_name]["col_id"])

        # numeric
        if hasattr(self, "numeric_vocab"):
            num_features = self.numeric_vocab.get("features", {})
            if feature_name in num_features:
                is_num = True
                meta = num_features[feature_name]
                num_n_in = int(meta["n_in"])
                num_v_idx = int(meta["col_id"])

        if is_cat and is_num:
            raise ValueError(f"Feature name appears in both categorical and numeric vocab: {feature_name}")
        if not is_cat and not is_num:
            raise KeyError(f"Unknown feature_name: {feature_name}")

        # --------------------------------------------------
        # Start from shallow copy only
        # --------------------------------------------------
        masked_batch = dict(batch)

        # Preserve original aliases (do NOT deepcopy)
        masked_batch["original_cat_local_ids"] = batch["cat_local_ids"]
        masked_batch["original_cat_valid_positions"] = batch["cat_valid_positions"]
        masked_batch["original_numeric_values_by_nin"] = batch["numeric_values_by_nin"]
        masked_batch["original_numeric_valid_positions_by_nin"] = batch["numeric_valid_positions_by_nin"]

        # --------------------------------------------------
        # Default: no masking anywhere
        # --------------------------------------------------
        masked_cat_local_ids = batch["cat_local_ids"].clone()
        masked_cat_valid_positions = batch["cat_valid_positions"].bool().clone()
        cat_loss_mask = torch.zeros((B, M), dtype=torch.bool, device=device)

        masked_numeric_values_by_nin = {}
        masked_numeric_valid_positions_by_nin = {}
        numeric_loss_mask_by_nin = {}

        for n_in in sorted(numeric_values_by_nin.keys(), key=int):
            values = numeric_values_by_nin[n_in]
            if n_in not in numeric_valid_positions_by_nin:
                raise KeyError(f"Missing numeric_valid_positions_by_nin[{n_in}]")

            valid_positions = numeric_valid_positions_by_nin[n_in]

            if values.dim() != 3:
                raise ValueError(
                    f"numeric_values_by_nin[{n_in}] must be [B, V, n_in], got {tuple(values.shape)}"
                )

            Bn, V, Nin = values.shape
            if Bn != B:
                raise ValueError(
                    f"numeric_values_by_nin[{n_in}] batch mismatch: got {Bn}, expected {B}"
                )
            if int(Nin) != int(n_in):
                raise ValueError(
                    f"numeric_values_by_nin[{n_in}] last dim mismatch: got {Nin}, expected {n_in}"
                )
            if valid_positions.shape != (B, V):
                raise ValueError(
                    f"numeric_valid_positions_by_nin[{n_in}] must be [B,V]=({B},{V}), "
                    f"got {tuple(valid_positions.shape)}"
                )

            masked_numeric_values_by_nin[n_in] = values.clone()
            masked_numeric_valid_positions_by_nin[n_in] = valid_positions.bool().clone()
            numeric_loss_mask_by_nin[n_in] = torch.zeros((B, V), dtype=torch.bool, device=values.device)

        # --------------------------------------------------
        # Apply single-feature masking
        # --------------------------------------------------
        if is_cat:
            original_valid = cat_valid_positions[:, cat_col].bool()  # [B]

            if assert_not_missing and not bool(original_valid.all().item()):
                n_bad = int((~original_valid).sum().item())
                raise ValueError(
                    f"Categorical feature '{feature_name}' has {n_bad} naturally missing samples in batch"
                )

            # only originally valid positions are actively masked
            cat_loss_mask[:, cat_col] = original_valid

            masked_cat_local_ids[cat_loss_mask] = cat_mask_local_ids.view(1, M).expand(B, M)[cat_loss_mask]
            masked_cat_valid_positions = masked_cat_valid_positions & (~cat_loss_mask)

        else:
            if num_n_in not in masked_numeric_values_by_nin:
                raise KeyError(f"numeric_values_by_nin does not contain n_in={num_n_in} for {feature_name}")

            values = masked_numeric_values_by_nin[num_n_in]
            valid_positions = masked_numeric_valid_positions_by_nin[num_n_in]
            num_loss_mask = numeric_loss_mask_by_nin[num_n_in]

            if num_v_idx >= values.shape[1]:
                raise IndexError(
                    f"Numeric feature '{feature_name}' resolved to v_idx={num_v_idx}, "
                    f"but numeric_values_by_nin[{num_n_in}] has V={values.shape[1]}"
                )

            original_valid = valid_positions[:, num_v_idx].bool()  # [B]

            if assert_not_missing and not bool(original_valid.all().item()):
                n_bad = int((~original_valid).sum().item())
                raise ValueError(
                    f"Numeric feature '{feature_name}' has {n_bad} naturally missing samples in batch"
                )

            # only originally valid positions are actively masked
            num_loss_mask[:, num_v_idx] = original_valid

            values[num_loss_mask] = 0.0
            valid_positions[:] = valid_positions & (~num_loss_mask)

        # --------------------------------------------------
        # Finalize outputs
        # --------------------------------------------------
        masked_batch["masked_cat_local_ids"] = masked_cat_local_ids
        masked_batch["masked_cat_valid_positions"] = masked_cat_valid_positions
        masked_batch["cat_loss_mask"] = cat_loss_mask

        masked_batch["masked_numeric_values_by_nin"] = masked_numeric_values_by_nin
        masked_batch["masked_numeric_valid_positions_by_nin"] = masked_numeric_valid_positions_by_nin
        masked_batch["numeric_loss_mask_by_nin"] = numeric_loss_mask_by_nin

        return masked_batch
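
# Added sketch (illustrative only): leave-one-feature-out masking, as used for
# per-feature evaluation. Exactly one feature is hidden across the whole batch;
# the returned loss masks mark the positions the model should reconstruct.
def _example_single_feature_mask(dataset, batch, feature_name):
    masked = dataset.perform_active_mask_single(
        batch, feature_name=feature_name, assert_not_missing=False
    )
    return masked["cat_loss_mask"], masked["numeric_loss_mask_by_nin"]
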
+
def build_train_eval_dataloaders(
|
| 716 |
+
dataset,
|
| 717 |
+
train_ratio=0.8,
|
| 718 |
+
seed=42,
|
| 719 |
+
batch_size=32,
|
| 720 |
+
):
|
| 721 |
+
n = len(dataset)
|
| 722 |
+
|
| 723 |
+
n_train = int(n * train_ratio)
|
| 724 |
+
n_eval = n - n_train
|
| 725 |
+
|
| 726 |
+
split_generator = torch.Generator().manual_seed(seed)
|
| 727 |
+
|
| 728 |
+
train_ds, eval_ds = torch.utils.data.random_split(
|
| 729 |
+
dataset,
|
| 730 |
+
[n_train, n_eval],
|
| 731 |
+
generator=split_generator
|
| 732 |
+
)
|
| 733 |
+
|
| 734 |
+
train_generator = torch.Generator()
|
| 735 |
+
|
| 736 |
+
train_loader = DataLoader(
|
| 737 |
+
train_ds,
|
| 738 |
+
batch_size=batch_size,
|
| 739 |
+
shuffle=True,
|
| 740 |
+
collate_fn=dataset.collate_fn,
|
| 741 |
+
generator=train_generator,
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
eval_loader = DataLoader(
|
| 745 |
+
eval_ds,
|
| 746 |
+
batch_size=batch_size,
|
| 747 |
+
shuffle=False,
|
| 748 |
+
collate_fn=dataset.collate_fn,
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
return train_loader, eval_loader, train_generator
|
| 752 |
+
|
| 753 |
+
|
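
# Added usage sketch (illustrative only; the constructor paths are placeholders
# pointing at the files shipped under data/, and photo_root is hypothetical):
def _example_build_loaders():
    dataset = SoilFormerDataset(
        csv_path="data/tabular_data.csv",
        photo_map_path="data/photo_map.json",
        cat_vocab_path="data/cat_vocab.json",
        numeric_vocab_path="data/numeric_vocab.json",
        numeric_stats_path="data/tabular_meta_numeric_stats.csv",
        photo_root="photos",  # hypothetical local root
    )
    train_loader, eval_loader, train_generator = build_train_eval_dataloaders(
        dataset, train_ratio=0.8, seed=42, batch_size=32
    )
    return train_loader, eval_loader, train_generator
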
def debug_print_first_sample(dataset, batch, batch_pos=0):
    """
    Inspect one sample in a batch.

    This debug function checks masked_* fields against the original csv row.
    Positions in loss_mask are allowed to mismatch.

    Args:
        dataset: SoilFormerDataset
        batch: collated + optionally masked batch
        batch_pos: index inside the batch (not dataset row index)
    """
    import math

    def numeric_list_close(a, b, atol=1e-6, rtol=1e-5):
        if len(a) != len(b):
            return False
        for x, y in zip(a, b):
            if not math.isclose(float(x), float(y), rel_tol=rtol, abs_tol=atol):
                return False
        return True

    def normalize_numeric_list(feat_name, vals, is_valid):
        if not is_valid:
            return [0.0] * len(vals)

        stat_row = dataset.numeric_stats_index.loc[feat_name]
        mean = float(stat_row["mean"])
        std = float(stat_row["std"])
        if std == 0.0:
            std = 1.0

        return [(float(v) - mean) / std for v in vals]

    if "row_idx" not in batch:
        raise KeyError("batch must contain 'row_idx' for debug_print_first_sample")
    if "sample_id" not in batch:
        raise KeyError("batch must contain 'sample_id' for debug_print_first_sample")

    row_idx = int(batch["row_idx"][batch_pos].item())
    row = dataset.df.iloc[row_idx]
    sample_id = batch["sample_id"][batch_pos]

    print("\n====================================================")
    print("DEBUG SAMPLE")
    print("====================================================")
    print("batch_pos :", batch_pos)
    print("row_idx   :", row_idx)
    print("sample_id :", sample_id)

    # ====================================================
    # categorical
    # ====================================================
    print("\n[CATEGORICAL FEATURES]")

    cat_ids = batch["masked_cat_local_ids"][batch_pos]
    cat_valids = batch["masked_cat_valid_positions"][batch_pos]
    cat_loss_mask = batch.get("cat_loss_mask", None)
    if cat_loss_mask is not None:
        cat_loss_mask = cat_loss_mask[batch_pos]

    for i, col in enumerate(dataset.cat_columns):
        raw = row[col]
        raw_str = str(raw)

        got_id = int(cat_ids[i].item())
        got_valid = bool(cat_valids[i].item())

        spec = dataset.cat_vocab[col]
        label2id = spec["label2id"]
        mask_id = int(spec["mask_local_id"])

        if raw == "":
            expected_id = mask_id
            expected_valid = False
        else:
            expected_id = int(label2id[raw])
            expected_valid = True

        is_loss_position = False
        if cat_loss_mask is not None:
            is_loss_position = bool(cat_loss_mask[i].item())

        if is_loss_position:
            ok = True
        else:
            ok = (got_id == expected_id) and (got_valid == expected_valid)

        print(
            f"{i:03d} | {col} | "
            f"raw={raw_str:<60} | "
            f"id={got_id:<6} | expected={expected_id:<6} | "
            f"valid={got_valid} | exp_valid={expected_valid} | "
            f"loss_mask={is_loss_position} | ok={ok}"
        )

        if not ok:
            raise AssertionError(
                f"\nCategorical mismatch\n"
                f"batch_pos={batch_pos}\n"
                f"row_idx={row_idx}\n"
                f"feature={col}\n"
                f"raw={raw}\n"
                f"id={got_id}, expected={expected_id}\n"
                f"valid={got_valid}, expected={expected_valid}"
            )

    # ====================================================
    # numeric
    # ====================================================
    print("\n[NUMERIC FEATURES]")

    numeric_loss_mask_by_nin = batch.get("numeric_loss_mask_by_nin", None)

    for group in dataset.numeric_groups:
        n_in = int(group["n_in"])
        features = group["feature_names"]

        values = batch["masked_numeric_values_by_nin"][n_in][batch_pos]
        valids = batch["masked_numeric_valid_positions_by_nin"][n_in][batch_pos]

        if numeric_loss_mask_by_nin is not None:
            loss_mask = numeric_loss_mask_by_nin[n_in][batch_pos]
        else:
            loss_mask = None

        print(f"\nGroup n_in={n_in}")

        for i, feat in enumerate(features):
            raw = row[feat]
            raw_str = str(raw)

            parsed, expected_valid = parse_numeric_cell(raw, n_in)
            expected_norm = normalize_numeric_list(feat, parsed, expected_valid)

            tensor_val = values[i].tolist()
            got_valid = bool(valids[i].item())

            is_loss_position = False
            if loss_mask is not None:
                is_loss_position = bool(loss_mask[i].item())

            if is_loss_position:
                ok = True
            else:
                value_ok = numeric_list_close(tensor_val, expected_norm)
                valid_ok = (got_valid == expected_valid)
                ok = value_ok and valid_ok

            print(
                f"{i:03d} | {feat} | "
                f"raw={raw_str:<60} | "
                f"tensor={tensor_val} | expected_norm={expected_norm} | "
                f"valid={got_valid} | exp_valid={expected_valid} | "
                f"loss_mask={is_loss_position} | ok={ok}"
            )

            if not ok:
                raise AssertionError(
                    f"\nNumeric mismatch\n"
                    f"batch_pos={batch_pos}\n"
                    f"row_idx={row_idx}\n"
                    f"feature={feat}\n"
                    f"raw={raw}\n"
                    f"tensor={tensor_val}\n"
                    f"expected={parsed}\n"
                    f"valid={got_valid}, expected={expected_valid}"
                )

    # ====================================================
    # vision
    # ====================================================
    print("\n[VISION]")

    try:
        relative_path = dataset.photo_map[sample_id]
        expected_path = join_photo_root(dataset.photo_root, relative_path)

        # Use the same logic as __getitem__: valid only if image can actually be loaded
        _ = dataset.load_image(expected_path)
        expected_valid = True

    except Exception:  # noqa
        expected_path = None
        expected_valid = False

    got_valid = bool(batch["vision_valid_positions"][batch_pos].item())
    img_shape = tuple(batch["pixel_values"][batch_pos].shape)

    print("expected_path :", expected_path)
    print("vision_valid  :", got_valid)
    print("image_shape   :", img_shape)

    if got_valid != expected_valid:
        raise AssertionError(
            f"\nVision validity mismatch\n"
            f"batch_pos={batch_pos}\n"
            f"row_idx={row_idx}\n"
            f"expected={expected_valid}, got={got_valid}"
|
| 953 |
+
)
|
| 954 |
+
|
| 955 |
+
print("\n====================================================")
|
| 956 |
+
print("DEBUG CHECK PASSED")
|
| 957 |
+
print("====================================================\n")
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
def main():
|
| 961 |
+
dataset = SoilFormerDataset(
|
| 962 |
+
csv_path="data/tabular_data.csv",
|
| 963 |
+
photo_map_path="data/photo_map.json",
|
| 964 |
+
cat_vocab_path="data/cat_vocab.json",
|
| 965 |
+
numeric_vocab_path="data/numeric_vocab.json",
|
| 966 |
+
numeric_stats_path="data/tabular_meta_numeric_stats.csv",
|
| 967 |
+
photo_root="/Volumes/TOSHIBA EXT",
|
| 968 |
+
image_size=512,
|
| 969 |
+
id_column="id",
|
| 970 |
+
)
|
| 971 |
+
|
| 972 |
+
train_loader, eval_loader, train_generator = build_train_eval_dataloaders(dataset)
|
| 973 |
+
|
| 974 |
+
print("Dataset size:", len(dataset))
|
| 975 |
+
|
| 976 |
+
raw_batch = next(iter(eval_loader))
|
| 977 |
+
batch = dataset.perform_active_mask(
|
| 978 |
+
raw_batch,
|
| 979 |
+
cat_ratio=0.15,
|
| 980 |
+
num_ratio=0.15,
|
| 981 |
+
seed=42,
|
| 982 |
+
)
|
| 983 |
+
|
| 984 |
+
print("\nBatch check")
|
| 985 |
+
if "row_idx" in batch:
|
| 986 |
+
print("row_idx:", batch["row_idx"].shape, batch["row_idx"].dtype)
|
| 987 |
+
if "sample_id" in batch:
|
| 988 |
+
print("sample_id:", len(batch["sample_id"]))
|
| 989 |
+
|
| 990 |
+
print("original_cat_local_ids:", batch["original_cat_local_ids"].shape)
|
| 991 |
+
print("masked_cat_local_ids:", batch["masked_cat_local_ids"].shape)
|
| 992 |
+
print("original_cat_valid_positions:", batch["original_cat_valid_positions"].shape)
|
| 993 |
+
print("masked_cat_valid_positions:", batch["masked_cat_valid_positions"].shape)
|
| 994 |
+
print("cat_loss_mask:", batch["cat_loss_mask"].shape)
|
| 995 |
+
|
| 996 |
+
for k, v in batch["original_numeric_values_by_nin"].items():
|
| 997 |
+
print(f"original_numeric_values_by_nin[{k}]:", v.shape)
|
| 998 |
+
|
| 999 |
+
for k, v in batch["masked_numeric_values_by_nin"].items():
|
| 1000 |
+
print(f"masked_numeric_values_by_nin[{k}]:", v.shape)
|
| 1001 |
+
|
| 1002 |
+
for k, v in batch["original_numeric_valid_positions_by_nin"].items():
|
| 1003 |
+
print(f"original_numeric_valid_positions_by_nin[{k}]:", v.shape)
|
| 1004 |
+
|
| 1005 |
+
for k, v in batch["masked_numeric_valid_positions_by_nin"].items():
|
| 1006 |
+
print(f"masked_numeric_valid_positions_by_nin[{k}]:", v.shape)
|
| 1007 |
+
|
| 1008 |
+
for k, v in batch["numeric_loss_mask_by_nin"].items():
|
| 1009 |
+
print(f"numeric_loss_mask_by_nin[{k}]:", v.shape)
|
| 1010 |
+
|
| 1011 |
+
print("pixel_values:", batch["pixel_values"].shape)
|
| 1012 |
+
print("vision_valid_positions:", batch["vision_valid_positions"].shape)
|
| 1013 |
+
|
| 1014 |
+
print("\nTensor dtype check")
|
| 1015 |
+
print("masked cat ids dtype:", batch["masked_cat_local_ids"].dtype)
|
| 1016 |
+
print("masked numeric dtype:", next(iter(batch["masked_numeric_values_by_nin"].values())).dtype)
|
| 1017 |
+
print("image dtype:", batch["pixel_values"].dtype)
|
| 1018 |
+
|
| 1019 |
+
print("\nLoader test finished successfully")
|
| 1020 |
+
|
| 1021 |
+
debug_print_first_sample(dataset, batch, batch_pos=0)
|
| 1022 |
+
|
| 1023 |
+
|
| 1024 |
+
if __name__ == "__main__":
|
| 1025 |
+
main()
|
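The self-test above also documents the masking API: perform_active_mask takes a collated batch and returns the original_* targets, the masked_* inputs, and the *_loss_mask tensors in one dict, so the same call serves training (a fresh seed per step) and evaluation (a fixed seed for reproducibility). A minimal sketch, assuming a dataset and loaders constructed as in main():

    raw_batch = next(iter(eval_loader))
    batch = dataset.perform_active_mask(raw_batch, cat_ratio=0.15, num_ratio=0.15, seed=0)
    targets = batch["original_cat_local_ids"]   # what the model should recover
    inputs = batch["masked_cat_local_ids"]      # what the model actually sees
    supervised = batch["cat_loss_mask"]         # marks the cells whose loss is counted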
modelling/soilformer.py
ADDED
|
@@ -0,0 +1,696 @@
|
| 1 |
+
# soilformer.py
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Dict, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F # noqa
|
| 12 |
+
|
| 13 |
+
from decode_categorical import CategoricalDecoder
|
| 14 |
+
from decode_numeric import NumericDecoder
|
| 15 |
+
from embed_categorical import (
|
| 16 |
+
CategoricalEmbedding,
|
| 17 |
+
build_cat_vocab_spec_from_meta,
|
| 18 |
+
get_categorical_feature_names_from_meta,
|
| 19 |
+
save_cat_vocab_json,
|
| 20 |
+
)
|
| 21 |
+
from embed_numeric import (
|
| 22 |
+
NumericEmbedding,
|
| 23 |
+
build_numeric_vocab_spec_from_meta,
|
| 24 |
+
)
|
| 25 |
+
from embed_vision_gemma3n import Gemma3nVisionFeatureExtractor
|
| 26 |
+
from layer import TabularImageGQALayer
|
| 27 |
+
from utils import load_json, save_json, get_dtype
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ============================================================
|
| 31 |
+
# SoilFormer
|
| 32 |
+
# ============================================================
|
| 33 |
+
|
| 34 |
+
class SoilFormer(nn.Module):
|
| 35 |
+
"""
|
| 36 |
+
Full model: embeddings -> TabularImageGQALayer stack -> decoders.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self, config: Dict, device: Optional[str] = None):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.config = dict(config)
|
| 42 |
+
|
| 43 |
+
dtype = get_dtype(self.config.get("dtype", "bfloat16"))
|
| 44 |
+
dev = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 45 |
+
|
| 46 |
+
# ---- Tabular dims
|
| 47 |
+
cat_hidden = int(self.config["cat_hidden_size"])
|
| 48 |
+
num_hidden = int(self.config["numeric_hidden_size"])
|
| 49 |
+
if cat_hidden != num_hidden:
|
| 50 |
+
raise ValueError("Expect cat_hidden_size == numeric_hidden_size for one tabular stream.")
|
| 51 |
+
self.tabular_dim = cat_hidden
|
| 52 |
+
|
| 53 |
+
# ---- Embeddings
|
| 54 |
+
self.embed_cat = CategoricalEmbedding(
|
| 55 |
+
hidden_size=cat_hidden,
|
| 56 |
+
cat_vocab_json=self.config["cat_vocab_json"],
|
| 57 |
+
)
|
| 58 |
+
self.embed_num = NumericEmbedding(
|
| 59 |
+
hidden_size=num_hidden,
|
| 60 |
+
numeric_vocab_json=self.config["numeric_vocab_json"],
|
| 61 |
+
middle_size=self.config.get("numeric_encode_middle_size", None),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# ---- Decoders
|
| 65 |
+
self.decode_cat = CategoricalDecoder(
|
| 66 |
+
hidden_size=cat_hidden,
|
| 67 |
+
cat_vocab_json=self.config["cat_vocab_json"],
|
| 68 |
+
middle_size=self.config.get("cat_decode_middle_size", None),
|
| 69 |
+
homoscedastic=self.config.get("cat_homoscedastic", True),
|
| 70 |
+
)
|
| 71 |
+
self.decode_num = NumericDecoder(
|
| 72 |
+
hidden_size=num_hidden,
|
| 73 |
+
numeric_vocab_json=self.config["numeric_vocab_json"],
|
| 74 |
+
middle_size=self.config.get("numeric_decode_middle_size", None),
|
| 75 |
+
homoscedastic=self.config.get("num_homoscedastic", True),
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# ---- Vision
|
| 79 |
+
self.vision_extractor = Gemma3nVisionFeatureExtractor.from_pretrained_vision_only_dir(
|
| 80 |
+
model_dir=self.config["vision_model_dir"],
|
| 81 |
+
map_location="cpu",
|
| 82 |
+
num_output_tokens_reduced=self.config["vision_num_output_tokens_reduced"],
|
| 83 |
+
num_heads_for_token_reduction=self.config["vision_num_heads_for_token_reduction"],
|
| 84 |
+
reducer_bottleneck_dim=self.config["vision_reducer_bottleneck_dim"],
|
| 85 |
+
reducer_project_back=self.config["vision_reducer_project_back"],
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# ---- Layers
|
| 89 |
+
L = int(self.config["layer_num_layers"])
|
| 90 |
+
self.layers = nn.ModuleList([
|
| 91 |
+
TabularImageGQALayer(
|
| 92 |
+
tabular_dim=self.tabular_dim,
|
| 93 |
+
vision_dim=self.vision_extractor.get_actual_hidden_dim(),
|
| 94 |
+
num_query_heads=int(self.config["layer_num_query_heads"]),
|
| 95 |
+
num_kv_heads=int(self.config["layer_num_kv_heads"]),
|
| 96 |
+
head_dim=int(self.config["layer_head_dim"]),
|
| 97 |
+
mlp_ratio=float(self.config["layer_mlp_ratio"]),
|
| 98 |
+
dropout=float(self.config["layer_dropout"]),
|
| 99 |
+
)
|
| 100 |
+
for _ in range(L)
|
| 101 |
+
])
|
| 102 |
+
|
| 103 |
+
# ---- Move
|
| 104 |
+
self.to(device=dev, dtype=dtype)
|
| 105 |
+
|
| 106 |
+
def init_weights(self, std: float = 0.02):
|
| 107 |
+
self.embed_cat.init_weights(std=std)
|
| 108 |
+
self.embed_num.init_weights(std=std)
|
| 109 |
+
|
| 110 |
+
self.decode_cat.init_weights(std=std)
|
| 111 |
+
self.decode_num.init_weights(std=std)
|
| 112 |
+
|
| 113 |
+
self.vision_extractor.init_weights(std=std)
|
| 114 |
+
|
| 115 |
+
for blk in self.layers:
|
| 116 |
+
blk.init_weights(std=std)
|
| 117 |
+
|
| 118 |
+
def forward(
|
| 119 |
+
self,
|
| 120 |
+
cat_local_ids: torch.LongTensor, # [B, M_cat]
|
| 121 |
+
numeric_values_by_nin: Dict[int, torch.Tensor], # {n_in: [B, V, n_in]}
|
| 122 |
+
cat_valid_positions: Optional[torch.Tensor] = None, # [B, M_cat] bool
|
| 123 |
+
numeric_valid_positions_by_nin: Optional[Dict[int, torch.Tensor]] = None, # {n_in: [B,V] bool}
|
| 124 |
+
pixel_values: Optional[torch.Tensor] = None, # [B, 3, H, W]
|
| 125 |
+
vision_valid_positions: Optional[torch.Tensor] = None, # [B] bool OR indices [K]
|
| 126 |
+
):
|
| 127 |
+
# ----------------------------
|
| 128 |
+
# Embeddings (tabular)
|
| 129 |
+
# ----------------------------
|
| 130 |
+
x_cat, cat_mask = self.embed_cat(
|
| 131 |
+
local_ids=cat_local_ids,
|
| 132 |
+
valid_positions=cat_valid_positions,
|
| 133 |
+
)
|
| 134 |
+
|
| 135 |
+
x_num, num_mask = self.embed_num(
|
| 136 |
+
values_by_nin=numeric_values_by_nin,
|
| 137 |
+
valid_positions_by_nin=numeric_valid_positions_by_nin,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
x_tab = torch.cat([x_cat, x_num], dim=1) # [B, T_tab, H]
|
| 141 |
+
|
| 142 |
+
B, T_tab, _ = x_tab.shape
|
| 143 |
+
M_cat = x_cat.size(1)
|
| 144 |
+
T_num = x_num.size(1)
|
| 145 |
+
|
| 146 |
+
# ----------------------------
|
| 147 |
+
# Tabular attention mask
|
| 148 |
+
# ----------------------------
|
| 149 |
+
cat_mask = cat_mask.to(device=x_tab.device, dtype=torch.long)
|
| 150 |
+
num_mask = num_mask.to(device=x_tab.device, dtype=torch.long)
|
| 151 |
+
|
| 152 |
+
if self.config["disable_tabular_attention_mask"]:
|
| 153 |
+
attention_mask_tab = torch.ones(B, T_tab, device=x_tab.device, dtype=torch.long)
|
| 154 |
+
else:
|
| 155 |
+
attention_mask_tab = torch.cat([cat_mask, num_mask], dim=1)
|
| 156 |
+
if attention_mask_tab.shape != (B, T_tab):
|
| 157 |
+
raise RuntimeError("Internal attention_mask_tab shape mismatch")
|
| 158 |
+
|
| 159 |
+
# ----------------------------
|
| 160 |
+
# Vision features
|
| 161 |
+
# ----------------------------
|
| 162 |
+
if pixel_values is None:
|
| 163 |
+
|
| 164 |
+
vision_features = None
|
| 165 |
+
vision_mask = None
|
| 166 |
+
|
| 167 |
+
else:
|
| 168 |
+
|
| 169 |
+
vision_features, vision_mask = self.vision_extractor(
|
| 170 |
+
pixel_values=pixel_values,
|
| 171 |
+
valid_positions=vision_valid_positions,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
if vision_features.shape[0] != B:
|
| 175 |
+
raise ValueError("vision_features batch mismatch with tabular batch")
|
| 176 |
+
|
| 177 |
+
if vision_mask.shape[0] != B or vision_mask.shape[1] != vision_features.shape[1]:
|
| 178 |
+
raise ValueError("vision_mask shape mismatch with vision_features")
|
| 179 |
+
|
| 180 |
+
vision_mask = vision_mask.to(
|
| 181 |
+
device=attention_mask_tab.device,
|
| 182 |
+
dtype=attention_mask_tab.dtype,
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
# ----------------------------
|
| 186 |
+
# Transformer blocks
|
| 187 |
+
# ----------------------------
|
| 188 |
+
for blk in self.layers: # type: TabularImageGQALayer
|
| 189 |
+
x_tab = blk(
|
| 190 |
+
x_tab=x_tab,
|
| 191 |
+
attention_mask=attention_mask_tab,
|
| 192 |
+
vision_features=vision_features,
|
| 193 |
+
vision_mask=vision_mask
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# ----------------------------
|
| 197 |
+
# Slice outputs
|
| 198 |
+
# ----------------------------
|
| 199 |
+
x_cat_out = x_tab[:, :M_cat, :]
|
| 200 |
+
x_num_out = x_tab[:, M_cat:M_cat + T_num, :]
|
| 201 |
+
|
| 202 |
+
# ----------------------------
|
| 203 |
+
# Decode
|
| 204 |
+
# ----------------------------
|
| 205 |
+
cat_logits_padded, cat_s, valid_class_mask = self.decode_cat(
|
| 206 |
+
x_cat_out,
|
| 207 |
+
return_padded=True,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
value_by_nin, s_by_nin = self.decode_num(
|
| 211 |
+
x_num_out
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
return cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab
|
| 215 |
+
|
| 216 |
+
def _checkpoint_state_dict(self) -> Dict[str, torch.Tensor]:
|
| 217 |
+
"""
|
| 218 |
+
State dict used for save/load.
|
| 219 |
+
|
| 220 |
+
Excludes pretrained frozen vision weights:
|
| 221 |
+
- vision_extractor.vision_tower.*
|
| 222 |
+
- vision_extractor.embed_vision.*
|
| 223 |
+
|
| 224 |
+
Keeps reducer weights if reducer exists.
|
| 225 |
+
"""
|
| 226 |
+
full_sd = self.state_dict()
|
| 227 |
+
out = {}
|
| 228 |
+
|
| 229 |
+
for k, v in full_sd.items():
|
| 230 |
+
if k.startswith("vision_extractor.vision_tower."):
|
| 231 |
+
continue
|
| 232 |
+
if k.startswith("vision_extractor.embed_vision."):
|
| 233 |
+
continue
|
| 234 |
+
out[k] = v
|
| 235 |
+
|
| 236 |
+
return out
|
| 237 |
+
|
| 238 |
+
def save_weights(self, path: str):
|
| 239 |
+
"""
|
| 240 |
+
Save model weights needed for SoilFormer training/inference,
|
| 241 |
+
excluding pretrained frozen vision weights.
|
| 242 |
+
"""
|
| 243 |
+
payload = {
|
| 244 |
+
"model_state_dict": self._checkpoint_state_dict(),
|
| 245 |
+
"config": self.config,
|
| 246 |
+
}
|
| 247 |
+
torch.save(payload, path)
|
| 248 |
+
|
| 249 |
+
def load_weights(self, path: str, map_location: str = "cpu", strict: bool = True):
|
| 250 |
+
"""
|
| 251 |
+
Load weights saved by save_weights().
|
| 252 |
+
|
| 253 |
+
Only the checkpoint-managed subset is loaded:
|
| 254 |
+
- embeddings / decoders / layers
|
| 255 |
+
- vision_extractor.reducer.* (if present)
|
| 256 |
+
|
| 257 |
+
Pretrained frozen vision weights are ignored here and are expected
|
| 258 |
+
to come from vision_model_dir during model construction.
|
| 259 |
+
"""
|
| 260 |
+
ckpt = torch.load(path, map_location=map_location)
|
| 261 |
+
|
| 262 |
+
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
| 263 |
+
sd = ckpt["model_state_dict"]
|
| 264 |
+
elif isinstance(ckpt, dict):
|
| 265 |
+
sd = ckpt
|
| 266 |
+
else:
|
| 267 |
+
raise ValueError(f"Unsupported checkpoint format: {path}")
|
| 268 |
+
|
| 269 |
+
expected_sd = self._checkpoint_state_dict()
|
| 270 |
+
|
| 271 |
+
# Only keep keys that belong to the checkpoint-managed subset
|
| 272 |
+
loadable_sd = {k: v for k, v in sd.items() if k in expected_sd}
|
| 273 |
+
|
| 274 |
+
missing = sorted(set(expected_sd.keys()) - set(loadable_sd.keys()))
|
| 275 |
+
unexpected = sorted(set(sd.keys()) - set(expected_sd.keys()))
|
| 276 |
+
|
| 277 |
+
# Actually load
|
| 278 |
+
load_info = self.load_state_dict(loadable_sd, strict=False)
|
| 279 |
+
|
| 280 |
+
# PyTorch may still report missing keys from the full model state_dict;
|
| 281 |
+
# keep only checkpoint-managed ones.
|
| 282 |
+
missing_after_load = [
|
| 283 |
+
k for k in load_info.missing_keys
|
| 284 |
+
if k in expected_sd
|
| 285 |
+
]
|
| 286 |
+
unexpected_after_load = [
|
| 287 |
+
k for k in load_info.unexpected_keys
|
| 288 |
+
if k in expected_sd
|
| 289 |
+
]
|
| 290 |
+
|
| 291 |
+
# Merge both sources of mismatch info
|
| 292 |
+
missing_final = sorted(set(missing) | set(missing_after_load))
|
| 293 |
+
unexpected_final = sorted(set(unexpected) | set(unexpected_after_load))
|
| 294 |
+
|
| 295 |
+
if strict and (missing_final or unexpected_final):
|
| 296 |
+
raise RuntimeError(
|
| 297 |
+
"Checkpoint load mismatch.\n"
|
| 298 |
+
f"Missing keys: {missing_final}\n"
|
| 299 |
+
f"Unexpected keys: {unexpected_final}"
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
return {
|
| 303 |
+
"missing_keys": missing_final,
|
| 304 |
+
"unexpected_keys": unexpected_final,
|
| 305 |
+
}
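A round-trip with these two helpers, as a minimal sketch (the checkpoint filename is arbitrary): save_weights() persists only the checkpoint-managed subset, so the frozen Gemma3n vision weights must already be present from vision_model_dir when the second model is constructed.

    cfg = load_json("config/config_model.json")
    model = SoilFormer(cfg)
    model.save_weights("soilformer_subset.pt")
    fresh = SoilFormer(cfg)                      # rebuilds frozen vision weights from disk
    info = fresh.load_weights("soilformer_subset.pt", strict=True)
    assert not info["missing_keys"] and not info["unexpected_keys"]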
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def loss_function(
|
| 309 |
+
x_cat: torch.Tensor, # [B,M,Cmax] padded logits
|
| 310 |
+
s_cat: torch.Tensor, # [B,M] log-variance
|
| 311 |
+
y_cat: torch.Tensor, # [B,M] class index
|
| 312 |
+
loss_mask_cat: torch.Tensor, # [B,M] 0/1
|
| 313 |
+
valid_class_mask: torch.Tensor, # [M,Cmax] bool
|
| 314 |
+
x_num: Dict[int, torch.Tensor], # {n_in: [B,V,n_in]}
|
| 315 |
+
s_num: Dict[int, torch.Tensor], # {n_in: [B,V]}
|
| 316 |
+
y_num: Dict[int, torch.Tensor], # {n_in: [B,V,n_in]}
|
| 317 |
+
loss_mask_num: Dict[int, torch.Tensor], # {n_in: [B,V]} 0/1
|
| 318 |
+
cat_temperature: float = 1.0,
|
| 319 |
+
reduction: str = "mean", # "mean" or "sum"
|
| 320 |
+
eps: float = 1e-12,
|
| 321 |
+
cat_s_bound: Optional[float] = None,
|
| 322 |
+
num_s_bound: Optional[float] = None,
|
| 323 |
+
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
|
| 324 |
+
"""
|
| 325 |
+
Strict loss for SoilFormer.
|
| 326 |
+
|
| 327 |
+
Categorical:
|
| 328 |
+
- Uses per-column CE over the valid class range only.
|
| 329 |
+
- Does NOT rely on padded logits values.
|
| 330 |
+
- s_cat[b,m] = log sigma^2 for categorical column m.
|
| 331 |
+
|
| 332 |
+
Numeric:
|
| 333 |
+
- Per-variable MSE averaged over n_in dimensions.
|
| 334 |
+
- s_num[n_in][b,v] = log sigma^2 for numeric variable v.
|
| 335 |
+
|
| 336 |
+
Optional soft bound:
|
| 337 |
+
If cat_s_bound or num_s_bound is not None, apply
|
| 338 |
+
s <- bound * tanh(s / bound)
|
| 339 |
+
before using s in heteroscedastic weighting.
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
total_loss: scalar (float32)
|
| 343 |
+
stats: dict with cat_loss, num_loss, cat_base, num_base, counts...
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
def _soft_bound_logvar(s_: torch.Tensor, bound: Optional[float]) -> torch.Tensor:
|
| 347 |
+
if bound is None:
|
| 348 |
+
return s_
|
| 349 |
+
b = float(bound)
|
| 350 |
+
if b <= 0:
|
| 351 |
+
# A non-positive bound disables heteroscedastic weighting by forcing s = 0
|
| 352 |
+
return torch.zeros_like(s_)
|
| 353 |
+
return b * torch.tanh(s_ / b)
|
| 354 |
+
|
| 355 |
+
# ---------------------------------------------------
|
| 356 |
+
# 1) Categorical loss (strict per-column CE)
|
| 357 |
+
# ---------------------------------------------------
|
| 358 |
+
if x_cat.dim() != 3:
|
| 359 |
+
raise ValueError(f"x_cat must be [B,M,Cmax], got {tuple(x_cat.shape)}")
|
| 360 |
+
|
| 361 |
+
B, M, Cmax = x_cat.shape
|
| 362 |
+
|
| 363 |
+
if s_cat.shape != (B, M):
|
| 364 |
+
raise ValueError(f"s_cat must be [B,M]=({B},{M}), got {tuple(s_cat.shape)}")
|
| 365 |
+
if y_cat.shape != (B, M):
|
| 366 |
+
raise ValueError(f"y_cat must be [B,M]=({B},{M}), got {tuple(y_cat.shape)}")
|
| 367 |
+
if loss_mask_cat.shape != (B, M):
|
| 368 |
+
raise ValueError(f"loss_mask_cat must be [B,M]=({B},{M}), got {tuple(loss_mask_cat.shape)}")
|
| 369 |
+
if valid_class_mask.shape != (M, Cmax):
|
| 370 |
+
raise ValueError(
|
| 371 |
+
f"valid_class_mask must be [M,Cmax]=({M},{Cmax}), got {tuple(valid_class_mask.shape)}"
|
| 372 |
+
)
|
| 373 |
+
|
| 374 |
+
x_cat_f = x_cat.float()
|
| 375 |
+
s_cat_f = _soft_bound_logvar(s_cat.float(), cat_s_bound)
|
| 376 |
+
y_cat_l = y_cat.long()
|
| 377 |
+
mcat = loss_mask_cat.float()
|
| 378 |
+
valid_class_mask = valid_class_mask.to(device=x_cat.device, dtype=torch.bool)
|
| 379 |
+
|
| 380 |
+
if cat_temperature != 1.0:
|
| 381 |
+
x_cat_f = x_cat_f / float(cat_temperature)
|
| 382 |
+
|
| 383 |
+
cat_loss_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
|
| 384 |
+
cat_base_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
|
| 385 |
+
cat_correct_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
|
| 386 |
+
|
| 387 |
+
# denominator = number of actively supervised categorical cells
|
| 388 |
+
cat_denom = mcat.sum().clamp_min(float(eps))
|
| 389 |
+
|
| 390 |
+
for m in range(M):
|
| 391 |
+
cm = int(valid_class_mask[m].sum().item()) # real class count for column m
|
| 392 |
+
if cm <= 0:
|
| 393 |
+
raise ValueError(f"Column {m} has no valid classes")
|
| 394 |
+
|
| 395 |
+
logits_m = x_cat_f[:, m, :cm] # [B, C_m]
|
| 396 |
+
target_m = y_cat_l[:, m] # [B]
|
| 397 |
+
s_m = s_cat_f[:, m] # [B]
|
| 398 |
+
mask_m = mcat[:, m] # [B]
|
| 399 |
+
|
| 400 |
+
active = mask_m > 0
|
| 401 |
+
if active.any():
|
| 402 |
+
tgt_active = target_m[active]
|
| 403 |
+
if (tgt_active < 0).any() or (tgt_active >= cm).any():
|
| 404 |
+
raise ValueError(f"y_cat contains invalid class id for categorical column {m}")
|
| 405 |
+
|
| 406 |
+
target_m_safe = target_m.clone()
|
| 407 |
+
target_m_safe[~active] = 0
|
| 408 |
+
|
| 409 |
+
ce_m = F.cross_entropy(
|
| 410 |
+
logits_m,
|
| 411 |
+
target_m_safe,
|
| 412 |
+
reduction="none",
|
| 413 |
+
) # [B], float32
|
| 414 |
+
|
| 415 |
+
# ---------------------------------------------------
|
| 416 |
+
# accuracy (only count active positions)
|
| 417 |
+
# ---------------------------------------------------
|
| 418 |
+
pred_m = logits_m.argmax(dim=-1) # [B]
|
| 419 |
+
correct_m = (pred_m == target_m_safe) & active # [B]
|
| 420 |
+
cat_correct_acc = cat_correct_acc + correct_m.float().sum()
|
| 421 |
+
|
| 422 |
+
# heteroscedastic weighting: exp(-s) * CE + s
|
| 423 |
+
L_m = torch.exp(-s_m) * ce_m + s_m # [B]
|
| 424 |
+
|
| 425 |
+
cat_loss_acc = cat_loss_acc + (L_m * mask_m).sum()
|
| 426 |
+
cat_base_acc = cat_base_acc + (ce_m * mask_m).sum()
|
| 427 |
+
|
| 428 |
+
if reduction == "mean":
|
| 429 |
+
cat_loss = cat_loss_acc / cat_denom
|
| 430 |
+
cat_base = cat_base_acc / cat_denom
|
| 431 |
+
elif reduction == "sum":
|
| 432 |
+
cat_loss = cat_loss_acc
|
| 433 |
+
cat_base = cat_base_acc
|
| 434 |
+
else:
|
| 435 |
+
raise ValueError(f"Unsupported reduction: {reduction}")
|
| 436 |
+
cat_acc = cat_correct_acc / cat_denom
|
| 437 |
+
|
| 438 |
+
# ---------------------------------------------------
|
| 439 |
+
# 2) Numeric loss (per-variable heteroscedastic MSE)
|
| 440 |
+
# ---------------------------------------------------
|
| 441 |
+
num_loss_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
|
| 442 |
+
num_base_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
|
| 443 |
+
num_denom_acc = torch.zeros((), device=x_cat.device, dtype=torch.float32)
|
| 444 |
+
|
| 445 |
+
for n_in, x in x_num.items():
|
| 446 |
+
if n_in not in y_num or n_in not in s_num or n_in not in loss_mask_num:
|
| 447 |
+
raise KeyError(f"Missing key n_in={n_in} in y_num/s_num/loss_mask_num")
|
| 448 |
+
|
| 449 |
+
y = y_num[n_in]
|
| 450 |
+
s = s_num[n_in]
|
| 451 |
+
m = loss_mask_num[n_in]
|
| 452 |
+
|
| 453 |
+
if x.shape != y.shape:
|
| 454 |
+
raise ValueError(
|
| 455 |
+
f"x_num[{n_in}] and y_num[{n_in}] shape mismatch: "
|
| 456 |
+
f"{tuple(x.shape)} vs {tuple(y.shape)}"
|
| 457 |
+
)
|
| 458 |
+
if x.dim() != 3:
|
| 459 |
+
raise ValueError(f"x_num[{n_in}] must be [B,V,n_in], got {tuple(x.shape)}")
|
| 460 |
+
|
| 461 |
+
Bb, V, Nin = x.shape
|
| 462 |
+
if Nin != n_in:
|
| 463 |
+
raise ValueError(f"x_num[{n_in}] last dim mismatch: got {Nin}, expected {n_in}")
|
| 464 |
+
if s.shape != (Bb, V):
|
| 465 |
+
raise ValueError(f"s_num[{n_in}] must be [B,V], got {tuple(s.shape)}")
|
| 466 |
+
if m.shape != (Bb, V):
|
| 467 |
+
raise ValueError(f"loss_mask_num[{n_in}] must be [B,V], got {tuple(m.shape)}")
|
| 468 |
+
|
| 469 |
+
x_f = x.float()
|
| 470 |
+
y_f = y.float()
|
| 471 |
+
s_f = _soft_bound_logvar(s.float(), num_s_bound)
|
| 472 |
+
m_f = m.float()
|
| 473 |
+
|
| 474 |
+
# base numeric loss per variable: mean over n_in dims
|
| 475 |
+
mse = (x_f - y_f).pow(2).mean(dim=-1) # [B,V]
|
| 476 |
+
|
| 477 |
+
# heteroscedastic weighting: exp(-s) * mse + s
|
| 478 |
+
L = torch.exp(-s_f) * mse + s_f # [B,V]
|
| 479 |
+
|
| 480 |
+
num_loss_acc = num_loss_acc + (L * m_f).sum()
|
| 481 |
+
num_base_acc = num_base_acc + (mse * m_f).sum()
|
| 482 |
+
num_denom_acc = num_denom_acc + m_f.sum()
|
| 483 |
+
|
| 484 |
+
num_denom = num_denom_acc.clamp_min(float(eps))
|
| 485 |
+
|
| 486 |
+
if reduction == "mean":
|
| 487 |
+
num_loss = num_loss_acc / num_denom
|
| 488 |
+
num_base = num_base_acc / num_denom
|
| 489 |
+
elif reduction == "sum":
|
| 490 |
+
num_loss = num_loss_acc
|
| 491 |
+
num_base = num_base_acc
|
| 492 |
+
else:
|
| 493 |
+
raise ValueError(f"Unsupported reduction: {reduction}")
|
| 494 |
+
|
| 495 |
+
# ---------------------------------------------------
|
| 496 |
+
# 3) Total
|
| 497 |
+
# ---------------------------------------------------
|
| 498 |
+
total = cat_loss + num_loss
|
| 499 |
+
|
| 500 |
+
stats = {
|
| 501 |
+
"total": total.detach(),
|
| 502 |
+
"cat_loss": cat_loss.detach(),
|
| 503 |
+
"num_loss": num_loss.detach(),
|
| 504 |
+
"cat_base": cat_base.detach(),
|
| 505 |
+
"num_base": num_base.detach(),
|
| 506 |
+
"cat_count": cat_denom.detach(),
|
| 507 |
+
"num_count": num_denom.detach(),
|
| 508 |
+
"cat_acc": cat_acc.detach(),
|
| 509 |
+
}
|
| 510 |
+
return total, stats
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
# ============================================================
|
| 514 |
+
# DEMO
|
| 515 |
+
# ============================================================
|
| 516 |
+
|
| 517 |
+
def _demo_main():
|
| 518 |
+
import argparse
|
| 519 |
+
|
| 520 |
+
parser = argparse.ArgumentParser()
|
| 521 |
+
parser.add_argument("--config_json", type=str, default="config/config_model.json")
|
| 522 |
+
parser.add_argument("--batch_size", type=int, default=2)
|
| 523 |
+
parser.add_argument("--with_vision", action="store_true")
|
| 524 |
+
args = parser.parse_args()
|
| 525 |
+
|
| 526 |
+
cfg = load_json(args.config_json)
|
| 527 |
+
|
| 528 |
+
print("===== Loaded config =====")
|
| 529 |
+
print(json.dumps(cfg, ensure_ascii=False, indent=2))
|
| 530 |
+
|
| 531 |
+
# --------------------------------------------------
|
| 532 |
+
# Ensure vocab files exist
|
| 533 |
+
# --------------------------------------------------
|
| 534 |
+
tabular_meta = load_json(cfg["tabular_meta"])
|
| 535 |
+
|
| 536 |
+
if not os.path.isfile(cfg["cat_vocab_json"]):
|
| 537 |
+
cat_names = get_categorical_feature_names_from_meta(tabular_meta)
|
| 538 |
+
vocab = build_cat_vocab_spec_from_meta(tabular_meta, cat_names)
|
| 539 |
+
Path(cfg["cat_vocab_json"]).parent.mkdir(parents=True, exist_ok=True)
|
| 540 |
+
save_cat_vocab_json(vocab, cfg["cat_vocab_json"])
|
| 541 |
+
print(f"[demo] Built cat_vocab_json at {cfg['cat_vocab_json']}")
|
| 542 |
+
|
| 543 |
+
if not os.path.isfile(cfg["numeric_vocab_json"]):
|
| 544 |
+
spec = build_numeric_vocab_spec_from_meta(tabular_meta)
|
| 545 |
+
Path(cfg["numeric_vocab_json"]).parent.mkdir(parents=True, exist_ok=True)
|
| 546 |
+
save_json(spec, cfg["numeric_vocab_json"])
|
| 547 |
+
print(f"[demo] Built numeric_vocab_json at {cfg['numeric_vocab_json']}")
|
| 548 |
+
|
| 549 |
+
# --------------------------------------------------
|
| 550 |
+
# Build model
|
| 551 |
+
# --------------------------------------------------
|
| 552 |
+
model = SoilFormer(cfg)
|
| 553 |
+
model.init_weights()
|
| 554 |
+
model.eval()
|
| 555 |
+
|
| 556 |
+
device = next(model.parameters()).device
|
| 557 |
+
dtype = next(model.parameters()).dtype
|
| 558 |
+
|
| 559 |
+
B = args.batch_size
|
| 560 |
+
|
| 561 |
+
# --------------------------------------------------
|
| 562 |
+
# Build dummy categorical inputs
|
| 563 |
+
# --------------------------------------------------
|
| 564 |
+
cat_spec = load_json(cfg["cat_vocab_json"])
|
| 565 |
+
cat_items = sorted(cat_spec.items(), key=lambda x: x[1]["col_id"])
|
| 566 |
+
M_cat = len(cat_items)
|
| 567 |
+
|
| 568 |
+
cat_local_ids = torch.zeros(B, M_cat, dtype=torch.long, device=device)
|
| 569 |
+
cat_valid_positions = torch.ones(B, M_cat, dtype=torch.bool, device=device)
|
| 570 |
+
|
| 571 |
+
# --------------------------------------------------
|
| 572 |
+
# Build dummy numeric inputs
|
| 573 |
+
# --------------------------------------------------
|
| 574 |
+
num_spec = load_json(cfg["numeric_vocab_json"])
|
| 575 |
+
|
| 576 |
+
numeric_values_by_nin: Dict[int, torch.Tensor] = {}
|
| 577 |
+
numeric_valid_positions_by_nin: Dict[int, torch.Tensor] = {}
|
| 578 |
+
|
| 579 |
+
for g in num_spec["groups"]:
|
| 580 |
+
n_in = int(g["n_in"])
|
| 581 |
+
V = len(g["feature_names"])
|
| 582 |
+
|
| 583 |
+
numeric_values_by_nin[n_in] = torch.randn(B, V, n_in, device=device, dtype=dtype)
|
| 584 |
+
numeric_valid_positions_by_nin[n_in] = torch.ones(B, V, dtype=torch.bool, device=device)
|
| 585 |
+
|
| 586 |
+
# --------------------------------------------------
|
| 587 |
+
# Build dummy vision inputs
|
| 588 |
+
# --------------------------------------------------
|
| 589 |
+
if args.with_vision:
|
| 590 |
+
pixel_values = torch.randn(B, 3, 224, 224, device=device, dtype=dtype)
|
| 591 |
+
vision_valid_positions = torch.ones(B, dtype=torch.bool, device=device)
|
| 592 |
+
else:
|
| 593 |
+
pixel_values = None
|
| 594 |
+
vision_valid_positions = None
|
| 595 |
+
|
| 596 |
+
# --------------------------------------------------
|
| 597 |
+
# Vision debug
|
| 598 |
+
# --------------------------------------------------
|
| 599 |
+
print("\n===== Vision debug =====")
|
| 600 |
+
if pixel_values is None:
|
| 601 |
+
print("pixel_values: None")
|
| 602 |
+
print("vision_features: None")
|
| 603 |
+
print("vision_mask: None")
|
| 604 |
+
else:
|
| 605 |
+
print("pixel_values:", tuple(pixel_values.shape), pixel_values.dtype, pixel_values.device)
|
| 606 |
+
with torch.no_grad():
|
| 607 |
+
vision_features, vision_mask = model.vision_extractor.forward(
|
| 608 |
+
pixel_values=pixel_values,
|
| 609 |
+
valid_positions=vision_valid_positions,
|
| 610 |
+
)
|
| 611 |
+
print("vision_features:", tuple(vision_features.shape), vision_features.dtype, vision_features.device)
|
| 612 |
+
print("vision_mask:", tuple(vision_mask.shape), vision_mask.dtype, vision_mask.device)
|
| 613 |
+
|
| 614 |
+
# --------------------------------------------------
|
| 615 |
+
# Forward
|
| 616 |
+
# --------------------------------------------------
|
| 617 |
+
with torch.no_grad():
|
| 618 |
+
cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, x_tab = model.forward(
|
| 619 |
+
cat_local_ids=cat_local_ids, # noqa
|
| 620 |
+
numeric_values_by_nin=numeric_values_by_nin,
|
| 621 |
+
cat_valid_positions=cat_valid_positions,
|
| 622 |
+
numeric_valid_positions_by_nin=numeric_valid_positions_by_nin,
|
| 623 |
+
pixel_values=pixel_values,
|
| 624 |
+
vision_valid_positions=vision_valid_positions,
|
| 625 |
+
)
|
| 626 |
+
|
| 627 |
+
print("\n===== SoilFormer demo =====")
|
| 628 |
+
print("cat_local_ids:", tuple(cat_local_ids.shape))
|
| 629 |
+
print("cat_valid_positions:", tuple(cat_valid_positions.shape))
|
| 630 |
+
print("numeric_values_by_nin:", {k: tuple(v.shape) for k, v in numeric_values_by_nin.items()})
|
| 631 |
+
print("numeric_valid_positions_by_nin:", {k: tuple(v.shape) for k, v in numeric_valid_positions_by_nin.items()})
|
| 632 |
+
print("x_tab_final:", tuple(x_tab.shape), x_tab.dtype, x_tab.device)
|
| 633 |
+
|
| 634 |
+
print("Categorical outputs:")
|
| 635 |
+
print("cat_logits_padded:", tuple(cat_logits_padded.shape), cat_logits_padded.dtype, cat_logits_padded.device)
|
| 636 |
+
print("cat_s:", tuple(cat_s.shape), cat_s.dtype, cat_s.device)
|
| 637 |
+
|
| 638 |
+
print("Numeric decoded values:", {k: tuple(v.shape) for k, v in value_by_nin.items()})
|
| 639 |
+
print("Numeric decoded s:", {k: tuple(s.shape) for k, s in s_by_nin.items()})
|
| 640 |
+
|
| 641 |
+
# --------------------------------------------------
|
| 642 |
+
# Loss debug
|
| 643 |
+
# --------------------------------------------------
|
| 644 |
+
print("\n===== Loss debug =====")
|
| 645 |
+
|
| 646 |
+
if cat_logits_padded.dim() != 3:
|
| 647 |
+
raise RuntimeError(f"cat_logits_padded must be [B,M,Cmax], got {tuple(cat_logits_padded.shape)}")
|
| 648 |
+
|
| 649 |
+
B_logits, M_cat2, Cmax2 = cat_logits_padded.shape
|
| 650 |
+
if cat_s.shape != (B_logits, M_cat2):
|
| 651 |
+
raise RuntimeError(f"cat_s shape mismatch: got {tuple(cat_s.shape)} expected {(B_logits, M_cat2)}")
|
| 652 |
+
|
| 653 |
+
# Build dummy categorical targets within valid class ranges
|
| 654 |
+
num_classes = [int(s["num_classes"]) for _, s in cat_items]
|
| 655 |
+
if len(num_classes) != M_cat2:
|
| 656 |
+
raise RuntimeError("M_cat mismatch between vocab and model output")
|
| 657 |
+
|
| 658 |
+
y_cat = torch.zeros(B_logits, M_cat2, dtype=torch.long, device=device)
|
| 659 |
+
for m, cm in enumerate(num_classes):
|
| 660 |
+
y_cat[:, m] = torch.randint(low=0, high=cm, size=(B_logits,), device=device)
|
| 661 |
+
|
| 662 |
+
mask_cat = torch.ones(B_logits, M_cat2, dtype=torch.long, device=device)
|
| 663 |
+
|
| 664 |
+
# Build dummy numeric targets and masks
|
| 665 |
+
y_num = {
|
| 666 |
+
n_in: torch.randn_like(x_pred)
|
| 667 |
+
for n_in, x_pred in value_by_nin.items()
|
| 668 |
+
}
|
| 669 |
+
|
| 670 |
+
mask_num = {
|
| 671 |
+
n_in: torch.ones(x_pred.size(0), x_pred.size(1), dtype=torch.long, device=x_pred.device)
|
| 672 |
+
for n_in, x_pred in value_by_nin.items()
|
| 673 |
+
}
|
| 674 |
+
|
| 675 |
+
total_loss, stats = loss_function(
|
| 676 |
+
x_cat=cat_logits_padded,
|
| 677 |
+
s_cat=cat_s,
|
| 678 |
+
y_cat=y_cat,
|
| 679 |
+
loss_mask_cat=mask_cat,
|
| 680 |
+
x_num=value_by_nin,
|
| 681 |
+
s_num=s_by_nin,
|
| 682 |
+
y_num=y_num,
|
| 683 |
+
loss_mask_num=mask_num,
|
| 684 |
+
reduction="mean",
|
| 685 |
+
valid_class_mask=valid_class_mask
|
| 686 |
+
)
|
| 687 |
+
|
| 688 |
+
print("total_loss:", float(total_loss))
|
| 689 |
+
print("stats:", {k: float(v) for k, v in stats.items()})
|
| 690 |
+
|
| 691 |
+
if not torch.isfinite(total_loss):
|
| 692 |
+
raise RuntimeError("Loss is not finite!")
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
if __name__ == "__main__":
|
| 696 |
+
_demo_main()
|
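The demo doubles as a smoke test and can be run directly; only the flags defined in _demo_main() exist, and all paths come from config/config_model.json:

    python modelling/soilformer.py
    python modelling/soilformer.py --batch_size 4 --with_vision

With freshly initialized weights the loss value itself is meaningless; the script only verifies tensor shapes and that the total loss is finite.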
modelling/train.py
ADDED
|
@@ -0,0 +1,552 @@
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Dict, Optional
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from torch.optim import AdamW
|
| 11 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LinearLR, SequentialLR
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
|
| 14 |
+
from loader import SoilFormerDataset, build_train_eval_dataloaders
|
| 15 |
+
from soilformer import SoilFormer, loss_function
|
| 16 |
+
from utils import get_dtype, load_json, save_json
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import wandb
|
| 20 |
+
except ImportError: # pragma: no cover
|
| 21 |
+
wandb = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def set_seed(seed: int, deterministic: bool = True) -> None:
|
| 25 |
+
random.seed(seed)
|
| 26 |
+
np.random.seed(seed)
|
| 27 |
+
torch.manual_seed(seed)
|
| 28 |
+
if torch.cuda.is_available():
|
| 29 |
+
torch.cuda.manual_seed(seed)
|
| 30 |
+
torch.cuda.manual_seed_all(seed)
|
| 31 |
+
|
| 32 |
+
if deterministic:
|
| 33 |
+
torch.backends.cudnn.deterministic = True
|
| 34 |
+
torch.backends.cudnn.benchmark = False
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def resolve_device(device_str: str) -> torch.device:
|
| 38 |
+
device_str = device_str.lower()
|
| 39 |
+
|
| 40 |
+
if device_str == "cuda":
|
| 41 |
+
if not torch.cuda.is_available():
|
| 42 |
+
raise RuntimeError("config requests cuda, but CUDA is not available")
|
| 43 |
+
return torch.device("cuda")
|
| 44 |
+
|
| 45 |
+
if device_str == "mps":
|
| 46 |
+
if not torch.backends.mps.is_available():
|
| 47 |
+
raise RuntimeError("config requests mps, but MPS is not available")
|
| 48 |
+
return torch.device("mps")
|
| 49 |
+
|
| 50 |
+
if device_str == "cpu":
|
| 51 |
+
return torch.device("cpu")
|
| 52 |
+
|
| 53 |
+
raise ValueError(f"Unsupported device: {device_str}")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def move_batch_to_device(batch: Dict, device: torch.device, float_dtype: torch.dtype) -> Dict:
|
| 57 |
+
out = {}
|
| 58 |
+
for key, value in batch.items():
|
| 59 |
+
if isinstance(value, torch.Tensor):
|
| 60 |
+
if value.dtype.is_floating_point:
|
| 61 |
+
out[key] = value.to(device=device, dtype=float_dtype, non_blocking=True)
|
| 62 |
+
else:
|
| 63 |
+
out[key] = value.to(device=device, non_blocking=True)
|
| 64 |
+
elif isinstance(value, dict):
|
| 65 |
+
sub = {}
|
| 66 |
+
for sub_key, sub_value in value.items():
|
| 67 |
+
if isinstance(sub_value, torch.Tensor):
|
| 68 |
+
if sub_value.dtype.is_floating_point:
|
| 69 |
+
sub[sub_key] = sub_value.to(device=device, dtype=float_dtype, non_blocking=True)
|
| 70 |
+
else:
|
| 71 |
+
sub[sub_key] = sub_value.to(device=device, non_blocking=True)
|
| 72 |
+
else:
|
| 73 |
+
sub[sub_key] = sub_value
|
| 74 |
+
out[key] = sub
|
| 75 |
+
else:
|
| 76 |
+
out[key] = value
|
| 77 |
+
return out
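Note that move_batch_to_device() recurses exactly one dict level deep, which matches the *_by_nin dictionaries the loader emits: floating tensors are cast to the training dtype while id and mask tensors keep their integer or bool dtypes. Given a loader batch and a resolved device (bfloat16 assumed, as in the default model config):

    batch = move_batch_to_device(raw_batch, device, torch.bfloat16)
    # pixel_values          -> bfloat16, moved to device
    # masked_cat_local_ids  -> stays int64, moved to device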
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def build_scheduler(
|
| 81 |
+
optimizer: torch.optim.Optimizer,
|
| 82 |
+
scheduler_cfg: Dict,
|
| 83 |
+
):
|
| 84 |
+
scheduler_type = str(scheduler_cfg.get("type", "none")).lower()
|
| 85 |
+
|
| 86 |
+
if scheduler_type == "none":
|
| 87 |
+
return None
|
| 88 |
+
|
| 89 |
+
warmup_epochs = int(scheduler_cfg.get("warmup_epochs", 0))
|
| 90 |
+
warmup_start_factor = float(scheduler_cfg.get("warmup_start_factor", 0.1))
|
| 91 |
+
|
| 92 |
+
if scheduler_type == "cosine":
|
| 93 |
+
total_epochs = int(scheduler_cfg["total_epochs"])
|
| 94 |
+
eta_min = float(scheduler_cfg.get("eta_min", 1e-6))
|
| 95 |
+
|
| 96 |
+
if warmup_epochs > 0:
|
| 97 |
+
t_max = int(scheduler_cfg.get("t_max", total_epochs - warmup_epochs))
|
| 98 |
+
if t_max <= 0:
|
| 99 |
+
raise ValueError(
|
| 100 |
+
f"Invalid cosine scheduler config: total_epochs={total_epochs}, "
|
| 101 |
+
f"warmup_epochs={warmup_epochs}, resulting T_max={t_max}"
|
| 102 |
+
)
|
| 103 |
+
else:
|
| 104 |
+
t_max = int(scheduler_cfg.get("t_max", total_epochs))
|
| 105 |
+
|
| 106 |
+
main_scheduler = CosineAnnealingLR(
|
| 107 |
+
optimizer,
|
| 108 |
+
T_max=t_max,
|
| 109 |
+
eta_min=eta_min,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
elif scheduler_type == "step":
|
| 113 |
+
step_size = int(scheduler_cfg["step_size"])
|
| 114 |
+
gamma = float(scheduler_cfg.get("gamma", 0.1))
|
| 115 |
+
main_scheduler = StepLR(
|
| 116 |
+
optimizer,
|
| 117 |
+
step_size=step_size,
|
| 118 |
+
gamma=gamma,
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
else:
|
| 122 |
+
raise ValueError(f"Unsupported scheduler type: {scheduler_type}")
|
| 123 |
+
|
| 124 |
+
if warmup_epochs <= 0:
|
| 125 |
+
return main_scheduler
|
| 126 |
+
|
| 127 |
+
warmup_scheduler = LinearLR(
|
| 128 |
+
optimizer,
|
| 129 |
+
start_factor=warmup_start_factor,
|
| 130 |
+
total_iters=warmup_epochs,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
scheduler = SequentialLR(
|
| 134 |
+
optimizer,
|
| 135 |
+
schedulers=[warmup_scheduler, main_scheduler],
|
| 136 |
+
milestones=[warmup_epochs],
|
| 137 |
+
)
|
| 138 |
+
return scheduler
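The scheduler block consumed above lives under optimization.scheduler in config_train.json; a hypothetical cosine-with-warmup configuration using only the keys this function actually reads (the numbers are placeholders, not the shipped defaults):

    scheduler_cfg = {
        "type": "cosine",
        "total_epochs": 200,
        "warmup_epochs": 10,
        "warmup_start_factor": 0.1,
        "eta_min": 1e-6,
    }
    scheduler = build_scheduler(optimizer, scheduler_cfg)

When warmup_epochs > 0, a LinearLR warmup is chained in front of the main schedule via SequentialLR, and T_max defaults to total_epochs - warmup_epochs unless t_max is given explicitly.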
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def get_checkpoint_model_state(model: SoilFormer) -> Dict[str, torch.Tensor]:
|
| 142 |
+
if hasattr(model, "_checkpoint_state_dict"):
|
| 143 |
+
return model._checkpoint_state_dict() # noqa
|
| 144 |
+
return model.state_dict()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def load_checkpoint_model_state(model: SoilFormer, state_dict: Dict[str, torch.Tensor]) -> None:
|
| 148 |
+
if hasattr(model, "load_weights"):
|
| 149 |
+
payload = {"model_state_dict": state_dict}
|
| 150 |
+
tmp_path = None
|
| 151 |
+
try:
|
| 152 |
+
import tempfile
|
| 153 |
+
with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
|
| 154 |
+
tmp_path = f.name
|
| 155 |
+
torch.save(payload, tmp_path)
|
| 156 |
+
model.load_weights(tmp_path, map_location="cpu", strict=True)
|
| 157 |
+
finally:
|
| 158 |
+
if tmp_path is not None and os.path.exists(tmp_path):
|
| 159 |
+
os.remove(tmp_path)
|
| 160 |
+
return
|
| 161 |
+
|
| 162 |
+
model.load_state_dict(state_dict, strict=True)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def save_checkpoint(
|
| 166 |
+
checkpoint_path: Path,
|
| 167 |
+
model: SoilFormer,
|
| 168 |
+
optimizer: torch.optim.Optimizer,
|
| 169 |
+
scheduler,
|
| 170 |
+
epoch: int,
|
| 171 |
+
global_step: int,
|
| 172 |
+
config_train: Dict,
|
| 173 |
+
config_model: Dict,
|
| 174 |
+
config_data: Dict,
|
| 175 |
+
) -> None:
|
| 176 |
+
checkpoint = {
|
| 177 |
+
"epoch": epoch,
|
| 178 |
+
"global_step": global_step,
|
| 179 |
+
"model_state_dict": get_checkpoint_model_state(model),
|
| 180 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
| 181 |
+
"scheduler_state_dict": None if scheduler is None else scheduler.state_dict(),
|
| 182 |
+
"config_train": config_train,
|
| 183 |
+
"config_model": config_model,
|
| 184 |
+
"config_data": config_data,
|
| 185 |
+
}
|
| 186 |
+
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
|
| 187 |
+
torch.save(checkpoint, checkpoint_path)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def rotate_checkpoints(checkpoint_dir: Path, max_saved_checkpoints: int) -> None:
|
| 191 |
+
checkpoint_paths = sorted(checkpoint_dir.glob("checkpoint_epoch_*.pt"))
|
| 192 |
+
if max_saved_checkpoints is None or max_saved_checkpoints <= 0:
|
| 193 |
+
return
|
| 194 |
+
while len(checkpoint_paths) > max_saved_checkpoints:
|
| 195 |
+
oldest = checkpoint_paths.pop(0)
|
| 196 |
+
oldest.unlink(missing_ok=True)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def compute_loss_from_batch(
|
| 200 |
+
model: SoilFormer,
|
| 201 |
+
batch: Dict,
|
| 202 |
+
device: torch.device,
|
| 203 |
+
dtype: torch.dtype,
|
| 204 |
+
cat_s_bound: Optional[float] = None,
|
| 205 |
+
num_s_bound: Optional[float] = None,
|
| 206 |
+
):
|
| 207 |
+
batch = move_batch_to_device(batch, device=device, float_dtype=dtype)
|
| 208 |
+
|
| 209 |
+
cat_logits_padded, cat_s, valid_class_mask, value_by_nin, s_by_nin, _ = model(
|
| 210 |
+
cat_local_ids=batch["masked_cat_local_ids"],
|
| 211 |
+
numeric_values_by_nin=batch["masked_numeric_values_by_nin"],
|
| 212 |
+
cat_valid_positions=batch["masked_cat_valid_positions"],
|
| 213 |
+
numeric_valid_positions_by_nin=batch["masked_numeric_valid_positions_by_nin"],
|
| 214 |
+
pixel_values=batch["pixel_values"],
|
| 215 |
+
vision_valid_positions=batch["vision_valid_positions"],
|
| 216 |
+
)
|
| 217 |
+
|
| 218 |
+
total_loss, stats = loss_function(
|
| 219 |
+
x_cat=cat_logits_padded,
|
| 220 |
+
s_cat=cat_s,
|
| 221 |
+
y_cat=batch["original_cat_local_ids"],
|
| 222 |
+
loss_mask_cat=batch["cat_loss_mask"],
|
| 223 |
+
valid_class_mask=valid_class_mask,
|
| 224 |
+
x_num=value_by_nin,
|
| 225 |
+
s_num=s_by_nin,
|
| 226 |
+
y_num=batch["original_numeric_values_by_nin"],
|
| 227 |
+
loss_mask_num=batch["numeric_loss_mask_by_nin"],
|
| 228 |
+
reduction="mean",
|
| 229 |
+
cat_s_bound=cat_s_bound,
|
| 230 |
+
num_s_bound=num_s_bound,
|
| 231 |
+
    )

    return total_loss, stats


@torch.no_grad()
def evaluate(
    model: SoilFormer,
    dataset: SoilFormerDataset,
    eval_loader,
    device: torch.device,
    dtype: torch.dtype,
    cat_mask_ratio: float,
    num_mask_ratio: float,
    active_mask_seed: int,
    show_tqdm: bool,
    epoch: int,
    cat_s_bound: Optional[float] = None,
    num_s_bound: Optional[float] = None,
):
    model.eval()

    totals = {
        "total": 0.0,
        "cat_loss": 0.0,
        "num_loss": 0.0,
        "cat_base": 0.0,
        "num_base": 0.0,
        "cat_acc": 0.0,
    }
    num_batches = 0

    iterator = eval_loader
    if show_tqdm:
        iterator = tqdm(eval_loader, desc=f"Eval {epoch}", leave=False)

    for batch_idx, raw_batch in enumerate(iterator):
        mask_seed = int(active_mask_seed + batch_idx)
        masked_batch = dataset.perform_active_mask(
            raw_batch,
            cat_ratio=cat_mask_ratio,
            num_ratio=num_mask_ratio,
            seed=mask_seed,
        )

        _, stats = compute_loss_from_batch(
            model=model,
            batch=masked_batch,
            device=device,
            dtype=dtype,
            cat_s_bound=cat_s_bound,
            num_s_bound=num_s_bound,
        )

        num_batches += 1
        for key in totals:
            totals[key] += float(stats[key].item())

    if num_batches == 0:
        raise RuntimeError("Eval dataloader is empty")

    return {f"eval/{k}": v / num_batches for k, v in totals.items()}


def maybe_init_wandb(config_train: Dict):
    wandb_cfg = config_train["logging"]["wandb"]
    if not bool(wandb_cfg.get("enabled", False)):
        return None

    if wandb is None:
        raise ImportError("wandb is enabled in config but package is not installed")

    run = wandb.init(
        project=wandb_cfg["project"],
        entity=wandb_cfg.get("entity"),
        name=wandb_cfg.get("run_name"),
        dir=wandb_cfg.get("dir"),
        config=config_train,
        mode=wandb_cfg.get("mode", "online"),
    )
    return run


def print_parameter_stats(model):
    total = 0
    trainable = 0

    for p in model.parameters():
        num = p.numel()
        total += num
        if p.requires_grad:
            trainable += num

    print("\nParameter statistics:")
    print(f"Total parameters: {total:,}")
    print(f"Trainable parameters: {trainable:,}")
    print(f"Frozen parameters: {total - trainable:,}\n")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="config/config_train.json")
    args = parser.parse_args()

    config_train = load_json(args.config)
    config_paths = config_train["paths"]
    config_data = load_json(config_paths["config_data_path"])
    config_model = load_json(config_paths["config_model_path"])

    seed_cfg = config_train["seed"]
    runtime_cfg = config_train["runtime"]
    optim_cfg = config_train["optimization"]
    checkpoint_cfg = config_train["checkpoint"]
    logging_cfg = config_train["logging"]
    loss_cfg = config_train["loss"]

    set_seed(int(seed_cfg["seed"]), deterministic=bool(seed_cfg.get("deterministic", True)))

    device = resolve_device(runtime_cfg["device"])
    dtype = get_dtype(config_model.get("dtype", "bfloat16"))

    output_dir = Path(config_paths["output_dir"])
    checkpoint_dir = output_dir / "checkpoints"
    output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    save_json(config_train, str(output_dir / "config_train.snapshot.json"))
    save_json(config_data, str(output_dir / "config_data.snapshot.json"))
    save_json(config_model, str(output_dir / "config_model.snapshot.json"))

    dataset = SoilFormerDataset(
        csv_path=config_data["data_csv_path"],
        photo_map_path=config_data["photo_map_path"],
        cat_vocab_path=config_data["cat_vocab_path"],
        numeric_vocab_path=config_data["numeric_vocab_path"],
        numeric_stats_path=config_data["numeric_stats_path"],
        photo_root=config_data["photo_root"],
        image_size=int(config_data["image_size"]),
    )

    train_loader, eval_loader, train_generator = build_train_eval_dataloaders(
        dataset=dataset,
        train_ratio=float(config_data["train_ratio"]),
        seed=int(config_data["train_eval_split_seed"]),
        batch_size=int(config_data["batch_size"]),
    )
    print("\nSample statistics:")
    print("Train samples:", len(train_loader.dataset))
    print("Eval samples:", len(eval_loader.dataset))
    train_generator.manual_seed(int(seed_cfg["seed"]))

    model = SoilFormer(config=config_model, device=str(device))

    resume_path = checkpoint_cfg.get("resume_checkpoint_path")
    if resume_path:
        checkpoint = torch.load(resume_path, map_location="cpu")
        load_checkpoint_model_state(model, checkpoint["model_state_dict"])
    else:
        model.init_weights(std=float(runtime_cfg.get("init_weight_std", 0.02)))
        checkpoint = None

    print_parameter_stats(model)

    optimizer = AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=float(optim_cfg["lr"]),
        betas=(float(optim_cfg["beta1"]), float(optim_cfg["beta2"])),
        eps=float(optim_cfg["eps"]),
        weight_decay=float(optim_cfg["weight_decay"]),
    )

    scheduler = build_scheduler(
        optimizer=optimizer,
        scheduler_cfg=optim_cfg.get("scheduler", {"type": "none"})
    )

    start_epoch = 1
    global_step = 0

    if checkpoint is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        if scheduler is not None and checkpoint.get("scheduler_state_dict") is not None:
            scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = int(checkpoint["epoch"]) + 1
        global_step = int(checkpoint.get("global_step", 0))

    wandb_run = maybe_init_wandb(config_train)

    num_epochs = int(runtime_cfg["num_epochs"])
    show_tqdm = bool(logging_cfg.get("tqdm", True))
    cat_mask_ratio = float(config_data["cat_mask_ratio"])
    num_mask_ratio = float(config_data["num_mask_ratio"])
    active_mask_seed = int(config_data["active_mask_seed"])
    max_grad_norm = optim_cfg.get("max_grad_norm")
    epochs_per_save = int(checkpoint_cfg["epochs_per_save"])
    max_saved_checkpoints = int(checkpoint_cfg["max_saved_checkpoints"])

    for epoch in range(start_epoch, num_epochs + 1):
        model.train()

        epoch_totals = {
            "total": 0.0,
            "cat_loss": 0.0,
            "num_loss": 0.0,
            "cat_base": 0.0,
            "num_base": 0.0,
            "cat_acc": 0.0,
        }
        num_batches = 0

        iterator = train_loader
        if show_tqdm:
            iterator = tqdm(train_loader, desc=f"Train {epoch}", leave=True)

        for batch_idx, raw_batch in enumerate(iterator):
            global_step += 1
            mask_seed = int(active_mask_seed + epoch * 1_000_000 + batch_idx)
            masked_batch = dataset.perform_active_mask(
                raw_batch,
                cat_ratio=cat_mask_ratio,
                num_ratio=num_mask_ratio,
                seed=mask_seed,
            )

            optimizer.zero_grad(set_to_none=True)

            total_loss, stats = compute_loss_from_batch(
                model=model,
                batch=masked_batch,
                device=device,
                dtype=dtype,
                cat_s_bound=loss_cfg.get("cat_s_bound", None),
                num_s_bound=loss_cfg.get("num_s_bound", None),
            )

            total_loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(max_grad_norm))
            optimizer.step()

            num_batches += 1
            for key in epoch_totals:
                epoch_totals[key] += float(stats[key].item())

            current_lr = float(optimizer.param_groups[0]["lr"])
            train_step_log = {
                "train/step_total": float(stats["total"].item()),
                "train/step_cat_loss": float(stats["cat_loss"].item()),
                "train/step_num_loss": float(stats["num_loss"].item()),
                "train/step_cat_acc": float(stats["cat_acc"].item()),
                "train/lr": current_lr,
                "epoch": epoch,
                "global_step": global_step,
            }

            if wandb_run is not None:
                wandb.log(train_step_log, step=global_step)

            if show_tqdm:
                iterator.set_postfix(
                    loss=f"{train_step_log['train/step_total']:.4f}",
                    lr=f"{current_lr:.3e}",
                )

        if num_batches == 0:
            raise RuntimeError("Train dataloader is empty")

        train_epoch_log = {f"train/{k}": v / num_batches for k, v in epoch_totals.items()}
        train_epoch_log["train/lr_epoch_end"] = float(optimizer.param_groups[0]["lr"])
        train_epoch_log["epoch"] = epoch
        train_epoch_log["global_step"] = global_step

        eval_log = evaluate(
            model=model,
            dataset=dataset,
            eval_loader=eval_loader,
            device=device,
            dtype=dtype,
            cat_mask_ratio=cat_mask_ratio,
            num_mask_ratio=num_mask_ratio,
            active_mask_seed=active_mask_seed,
            show_tqdm=show_tqdm,
            epoch=epoch,
            cat_s_bound=loss_cfg.get("cat_s_bound", None),
            num_s_bound=loss_cfg.get("num_s_bound", None),
        )
        eval_log["epoch"] = epoch
        eval_log["global_step"] = global_step

        merged_log = {}
        merged_log.update(train_epoch_log)
        merged_log.update(eval_log)

        print(json.dumps(merged_log, ensure_ascii=False))

        if wandb_run is not None:
            wandb.log(merged_log, step=global_step)

        if scheduler is not None:
            scheduler.step()

        if epochs_per_save > 0 and epoch % epochs_per_save == 0:
            checkpoint_path = checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
            save_checkpoint(
                checkpoint_path=checkpoint_path,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                epoch=epoch,
                global_step=global_step,
                config_train=config_train,
                config_model=config_model,
                config_data=config_data,
            )
            rotate_checkpoints(checkpoint_dir, max_saved_checkpoints)

    if wandb_run is not None:
        wandb.finish()


if __name__ == "__main__":
    main()
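For orientation: `main()` above consumes `config/config_train.json` through seven sections (`paths`, `seed`, `runtime`, `optimization`, `checkpoint`, `logging`, `loss`). A minimal sketch of that structure, written here as a Python dict with placeholder values — the shipped config may set different values and carry extra keys:

# Sketch of the structure main() reads; keys mirror the accesses in train.py,
# values are illustrative placeholders only.
config_train = {
    "paths": {
        "config_data_path": "config/config_data.json",
        "config_model_path": "config/config_model.json",
        "output_dir": "runs/soilformer_pretrain",  # hypothetical
    },
    "seed": {"seed": 42, "deterministic": True},
    "runtime": {"device": "cuda", "num_epochs": 200, "init_weight_std": 0.02},
    "optimization": {
        "lr": 1e-4,
        "beta1": 0.9,
        "beta2": 0.999,
        "eps": 1e-8,
        "weight_decay": 0.01,
        "max_grad_norm": 1.0,  # optional; null/None disables clipping
        "scheduler": {"type": "none"},
    },
    "checkpoint": {
        "resume_checkpoint_path": None,  # set to a .pt path to resume
        "epochs_per_save": 10,
        "max_saved_checkpoints": 3,
    },
    "logging": {
        "tqdm": True,
        "wandb": {"enabled": False, "project": "soilformer"},
    },
    "loss": {"cat_s_bound": None, "num_s_bound": None},
}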
modelling/utils.py
ADDED
@@ -0,0 +1,132 @@

# utils.py
# -*- coding: utf-8 -*-

import json
from typing import Dict
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa


class GroupedMLP(nn.Module):
    """
    Batched per-variable MLP for a fixed n_in bucket.

    Input:  X [B, V, n_in]
    Output: Y [B, V, n_out]

    Per-variable weights (NOT shared across V):
    - 1-layer: W  [V, n_out, n_in], b  [V, n_out]
    - 2-layer: W1 [V, mid, n_in],   b1 [V, mid]
               W2 [V, n_out, mid],  b2 [V, n_out]
    """

    def __init__(
        self,
        n_var: int,
        n_in: int,
        n_out: int,
        middle_size: Optional[int] = None,
        bias: bool = True,
    ):
        super().__init__()

        self.n_var = int(n_var)
        self.n_in = int(n_in)
        self.n_out = int(n_out)
        self.middle_size = None if middle_size is None else int(middle_size)
        self.bias = bias

        if self.middle_size is None:
            self.W = nn.Parameter(torch.empty(self.n_var, self.n_out, self.n_in))

            if bias:
                self.b = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b", None)

            self.W1 = self.b1 = self.W2 = self.b2 = None

        else:
            mid = self.middle_size

            self.W1 = nn.Parameter(torch.empty(self.n_var, mid, self.n_in))
            self.W2 = nn.Parameter(torch.empty(self.n_var, self.n_out, mid))

            if bias:
                self.b1 = nn.Parameter(torch.empty(self.n_var, mid))
                self.b2 = nn.Parameter(torch.empty(self.n_var, self.n_out))
            else:
                self.register_parameter("b1", None)
                self.register_parameter("b2", None)

            self.W = self.b = None

    def init_weights(self, std: float = 0.02) -> None:
        """
        Initialize weights manually.
        """
        if self.middle_size is None:
            nn.init.normal_(self.W, std=std)
            if self.bias:
                nn.init.zeros_(self.b)
        else:
            nn.init.normal_(self.W1, std=std)
            nn.init.normal_(self.W2, std=std)

            if self.bias:
                nn.init.zeros_(self.b1)
                nn.init.zeros_(self.b2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 3:
            raise ValueError(f"Expected x [B,V,n_in], got {tuple(x.shape)}")

        B, V, I = x.shape

        if V != self.n_var or I != self.n_in:
            raise ValueError(
                f"Shape mismatch: expected V={self.n_var}, n_in={self.n_in}; got V={V}, n_in={I}"
            )

        if self.middle_size is None:
            y = torch.einsum("bvi,voi->bvo", x, self.W)
            if self.bias:
                y = y + self.b.unsqueeze(0)
            return y

        h = torch.einsum("bvi,vmi->bvm", x, self.W1)
        if self.bias:
            h = h + self.b1.unsqueeze(0)

        h = F.gelu(h)

        y = torch.einsum("bvm,vom->bvo", h, self.W2)
        if self.bias:
            y = y + self.b2.unsqueeze(0)

        return y


def get_dtype(dtype: Optional[str]) -> torch.dtype:
    dtype_str = (dtype or "bfloat16").lower()
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if dtype_str not in dtype_map:
        raise ValueError(f"Unsupported dtype={dtype}. Choose from {list(dtype_map.keys())}")
    return dtype_map[dtype_str]


def load_json(path: str):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def save_json(obj: Dict, path: str) -> None:
    with open(path, "w", encoding="utf-8") as f:
        json.dump(obj, f, ensure_ascii=False, indent=2)  # noqa
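As a quick sanity check of `GroupedMLP`'s per-variable semantics, the sketch below (hypothetical shapes, assuming the repo root is on `PYTHONPATH`; not part of the repo) verifies that the einsum path matches an explicit loop over variables:

import torch
from modelling.utils import GroupedMLP

B, V, n_in, n_out = 2, 5, 8, 3
mlp = GroupedMLP(n_var=V, n_in=n_in, n_out=n_out, middle_size=16)
mlp.init_weights(std=0.02)  # parameters start from torch.empty, so init is required

x = torch.randn(B, V, n_in)
y = mlp(x)
assert y.shape == (B, V, n_out)

# "bvi,vmi->bvm" applies variable v's own weight matrix to its own slice,
# i.e. it is a batched per-variable matmul:
h_loop = torch.stack([x[:, v] @ mlp.W1[v].T + mlp.b1[v] for v in range(V)], dim=1)
h_ein = torch.einsum("bvi,vmi->bvm", x, mlp.W1) + mlp.b1.unsqueeze(0)
assert torch.allclose(h_loop, h_ein, atol=1e-5)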
requirements.txt
ADDED
@@ -0,0 +1,10 @@

torch~=2.10.0
numpy~=2.3.4
wandb~=0.25.1
tqdm~=4.67.1
pandas~=2.3.3
requests~=2.32.5
pillow~=12.0.0
torchvision~=0.25.0
safetensors~=0.7.0
transformers~=5.2.0
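These are compatible-release pins (`~=2.10.0` accepts any 2.10.x patch release but not 2.11), installable with `pip install -r requirements.txt`.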
resources/arch.png
ADDED
Git LFS Details