dwahdany commited on
Commit
6914420
·
verified ·
1 Parent(s): 58a4549

Add format section explaining safetensors compatibility

Browse files
Files changed (1) hide show
  1. README.md +9 -0
README.md CHANGED
@@ -66,14 +66,23 @@ Pre-trained [TabDPT](https://github.com/JesseCresswell/tfm-mia) model weights tr
66
  | CTR Correlation | 0.830 | 0.830 | 0.827 |
67
  | CTR R² | 0.726 | 0.730 | 0.725 |
68
 
 
 
 
 
69
  ## Usage
70
 
71
  ```python
72
  from tabdpt import TabDPTClassifier
73
  from huggingface_hub import hf_hub_download
74
 
 
75
  path = hf_hub_download("dwahdany/TabDPT", "production_seed42.safetensors")
 
 
76
  clf = TabDPTClassifier(model_weight_path=path)
77
  clf.fit(X_train, y_train)
78
  preds = clf.predict(X_test)
79
  ```
 
 
 
66
  | CTR Correlation | 0.830 | 0.830 | 0.827 |
67
  | CTR R² | 0.726 | 0.730 | 0.725 |
68
 
69
+ ## Format
70
+
71
+ These checkpoints were converted from PyTorch Lightning `.ckpt` files (which include optimizer state, ~295MB each) to SafeTensors format (model weights only, ~103MB each). This is the same format used by the official `Layer6/TabDPT` release. The `tabdpt` package natively loads SafeTensors via the `model_weight_path` argument — no extra conversion needed.
72
+
73
  ## Usage
74
 
75
  ```python
76
  from tabdpt import TabDPTClassifier
77
  from huggingface_hub import hf_hub_download
78
 
79
+ # Download once (cached afterwards)
80
  path = hf_hub_download("dwahdany/TabDPT", "production_seed42.safetensors")
81
+
82
+ # Use exactly like the default model
83
  clf = TabDPTClassifier(model_weight_path=path)
84
  clf.fit(X_train, y_train)
85
  preds = clf.predict(X_test)
86
  ```
87
+
88
+ Works identically with `TabDPTRegressor`.