download
raw
2.17 kB
#!/bin/bash
# Fine-tune (or only evaluate) a model with LoRA via mlx_lm.lora.
#
# Steps:
#   1. Activate the project virtualenv and byte-compile src/ as a syntax check.
#   2. Run src/prepare_data.py (its stdout is saved to logs/data_prep.json);
#      mlx_lm.lora then reads its data from ./data.
#   3. Run mlx_lm.lora: train + test, or test only when SKIP_TRAIN=1.
#   4. Parse the run log and print METRIC lines for the split sizes.
#
# Run from the project root: .venv/, src/, config/ and data/ are resolved
# relative to the current working directory.
# All tunables are environment variables with defaults assigned in this script.
set -euo pipefail
source .venv/bin/activate
# Fail fast on a syntax error in src/ before spending time on data prep or
# training. Diagnostics go to stderr so they never mix with captured stdout.
python -m compileall -q src/ || { echo "Syntax error" >&2; exit 1; }
# ---- Tunables -------------------------------------------------------------
# Each `: "${VAR:=default}"` keeps the caller's value when it is set and
# non-empty, and otherwise assigns the default shown.

# Paths.
: "${MODEL_DIR:=./model}"
: "${DATA_CONFIG:=./config/data_config.json}"
: "${ADAPTER_PATH:=./adapters}"

# Data preparation (forwarded to src/prepare_data.py --fresh-val-seed).
: "${FRESH_VAL_SEED:=0}"

# Mode: "1" runs mlx_lm.lora with --test only and an empty --adapter-path.
: "${SKIP_TRAIN:=0}"

# Optimisation.
: "${TRAIN_ITERS:=300}"
: "${BATCH_SIZE:=2}"
: "${GRAD_ACC:=8}"
: "${NUM_LAYERS:=16}"
: "${LEARNING_RATE:=2e-5}"
: "${MAX_SEQ_LENGTH:=1024}"

# Evaluation sizes; -1 is passed through to mlx_lm.lora, which (per its docs)
# treats it as "use the whole split".
: "${VAL_BATCHES:=-1}"
: "${TEST_BATCHES:=-1}"

# Reporting / checkpoint cadence, in iterations.
: "${STEPS_PER_REPORT:=25}"
: "${STEPS_PER_EVAL:=100}"
: "${SAVE_EVERY:=200}"

# "1" adds --mask-prompt to the mlx_lm.lora invocation.
: "${MASK_PROMPT:=1}"
# Output locations used later in the script: logs/ for run logs, data/ for the
# dataset mlx_lm.lora reads, and the adapter directory for LoRA weights.
mkdir -p logs data "$ADAPTER_PATH"

# Build the dataset. The script's stdout is captured as the data-prep report
# (presumably JSON, per the file name).
python src/prepare_data.py \
  --config "$DATA_CONFIG" \
  --fresh-val-seed "$FRESH_VAL_SEED" \
  >logs/data_prep.json
# One log file per run, named with a second-resolution timestamp.
RUN_TS="$(date +%Y%m%d-%H%M%S)"
LOG_PATH="logs/train-${RUN_TS}.log"

# Assemble mlx_lm.lora's argv in an array: every value stays exactly one word
# (even if empty or containing spaces), optional flags are appended instead of
# relying on unquoted word-splitting, and the flags common to both modes are
# written once so the two modes cannot drift apart.
LORA_ARGS=(
  --model "$MODEL_DIR"
  --data ./data
  --test
  --batch-size "$BATCH_SIZE"
  --test-batches "$TEST_BATCHES"
  --max-seq-length "$MAX_SEQ_LENGTH"
)

if [[ "$SKIP_TRAIN" == "1" ]]; then
  # Test-only run. An empty --adapter-path tells mlx_lm.lora not to load any
  # adapter weights (per its test-without-train handling; confirm for the
  # installed mlx_lm version).
  LORA_ARGS+=(--adapter-path "")
else
  LORA_ARGS+=(
    --train
    --adapter-path "$ADAPTER_PATH"
    --iters "$TRAIN_ITERS"
    --grad-accumulation-steps "$GRAD_ACC"
    --num-layers "$NUM_LAYERS"
    --learning-rate "$LEARNING_RATE"
    --val-batches "$VAL_BATCHES"
    --steps-per-report "$STEPS_PER_REPORT"
    --steps-per-eval "$STEPS_PER_EVAL"
    --save-every "$SAVE_EVERY"
  )
fi

if [[ "$MASK_PROMPT" == "1" ]]; then
  LORA_ARGS+=(--mask-prompt)
fi

# The array is never empty, so "${LORA_ARGS[@]}" is safe under `set -u` even
# on bash 3.2. tee mirrors stdout to the log; under `set -o pipefail` a
# failing mlx_lm.lora still fails the pipeline.
mlx_lm.lora "${LORA_ARGS[@]}" | tee "$LOG_PATH"
# Parse the mlx_lm.lora log captured for this run (see
# src/parse_mlx_lora_log.py for what it extracts).
python src/parse_mlx_lora_log.py "$LOG_PATH"

# Print the split sizes recorded in data/stats.json as METRIC lines, in the
# order train, valid, test. The quoted heredoc delimiter disables shell
# expansion inside the Python source.
python - <<'PY'
import json

with open('data/stats.json') as fh:
    stats = json.load(fh)

for label, key in (
    ("train_examples", "train_count"),
    ("valid_examples", "valid_count"),
    ("test_examples", "test_count"),
):
    print(f"METRIC {label}={stats[key]}")
PY

Xet Storage Details

Size:
2.17 kB
·
Xet hash:
5e39d1f5091446981770240d24652278aa705b6fda24d12ed9a981cb4ca70f8a

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.