5L.py: Implementation Guide for the Low-Rank Autoregressive Model
Introduction
This guide provides comprehensive instructions for setting up and operating the 5L.py script. The script is a versatile tool designed for both training a low-rank autoregressive model and using it for text generation (decoding). It is engineered with a robust feature set that includes fresh-start safety to prevent accidental overwrites, automatic mixed precision (AMP) for performance, out-of-memory (OOM) backoff for stability, and progressive block growth for efficient training on longer sequences. For inference, it offers a rich suite of sampling controls to precisely shape the generated output. This document is intended for machine learning engineers and practitioners who need to deploy, train, and run inference with this script.
Prerequisites and Initial Setup
Properly configuring your environment and installing the necessary dependencies is the essential first step to ensure the script functions correctly. This section outlines the required setup before you can begin training or inference.
Environment Configuration
The script's behavior is influenced by a key environment variable and the availability of specific hardware.
- Tokenizer ID: The script uses a specific tokenizer from the Hugging Face Hub, configured via the TOKENIZER_ID environment variable. If this variable is not set, it defaults to Qwen/Qwen3-235B-A22B-Thinking-2507. To use a different compatible tokenizer, set the variable in your shell before launching the script, e.g. export TOKENIZER_ID=<hub-tokenizer-id>.
- Hardware: The script is optimized for CUDA-enabled GPUs. It automatically detects the available device (torch.device("cuda" if torch.cuda.is_available() else "cpu")). Critical performance features, including Automatic Mixed Precision (AMP) and FP8 computation, are only available when running on a CUDA device.
Python Dependencies
The script relies on several core Python libraries. Ensure the following are installed in your Python environment:
- torch
- datasets
- transformers
- tqdm
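If you are starting from a clean environment, the dependencies listed above can typically be installed in one step (the package names here are assumed to match the standard PyPI distributions):

```shell
pip install torch datasets transformers tqdm
```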
With the environment configured and dependencies in place, you are ready to explore the core concepts of the model and its training architecture.
Core Concepts and Architecture
Understanding the script's foundational concepts is crucial for making effective use of its command-line arguments. Grasping the model presets and checkpointing strategy will allow you to control training and inference with greater precision and intent.
Model Configuration Presets
The script includes several built-in architecture presets to simplify model configuration. These presets define the model's size and complexity and can be selected using the --preset argument during training or inference.
- small: dimensions (d) = 512, layers = 8, heads = 16, rank (r) = 64
- smallx2: dimensions (d) = 512, layers = 16, heads = 16, rank (r) = 64
- base: dimensions (d) = 768, layers = 12, heads = 24, rank (r) = 96
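For reference, the presets above can be expressed as plain dictionaries. The key names used here (d, layers, heads, rank) are illustrative and may not match the script's internal field names:

```python
# Illustrative sketch of the built-in presets; values taken from the table above.
PRESETS = {
    "small":   {"d": 512, "layers": 8,  "heads": 16, "rank": 64},
    "smallx2": {"d": 512, "layers": 16, "heads": 16, "rank": 64},
    "base":    {"d": 768, "layers": 12, "heads": 24, "rank": 96},
}
```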
Checkpointing and State Management
The script employs a robust, time-based checkpointing strategy to ensure training progress is saved reliably.
- Time-Based Saving: Checkpoints are saved at a regular time interval defined by the --save_every_sec argument, not after a fixed number of steps. This ensures that progress is saved consistently, regardless of training speed.
- Checkpoint Contents: Each saved checkpoint is a .pt file containing a comprehensive snapshot of the training state for full reproducibility, as detailed in the save_ckpt function:
- Model state dictionaries for the core encoder and ar head.
- The opt optimizer state dictionary.
- The scaler state dictionary for mixed precision training.
- Model configuration (cfg).
- Training metadata, including the current step, total tokens seen (seen_tok), and wall_time.
- RNG states for reproducibility: py_state (Python) and torch_state (PyTorch).
The script distinguishes between two ways of loading a checkpoint, each serving a different purpose:
- Resuming (--resume): This is used to continue a previously interrupted training job. It restores the entire training state, including the model weights, optimizer state, gradient scaler, step count, and tokens seen. The training process continues exactly where it left off. When a directory is provided, the script intelligently locates the most recently modified valid checkpoint, even ignoring temporary files from incomplete saves.
- Warm-Starting (--warmstart_from): This is used to initialize a new model's weights from a previously saved checkpoint, typically a final.pt file. It only loads weights for layers with matching shapes and does not restore the optimizer or other training states. This is ideal for knowledge transfer, such as starting a larger model's training with the weights of a smaller, pre-trained one.
Understanding these concepts provides the foundation for the practical application of the script's training capabilities.
Training the Model (train command)
The train command is the entry point for initiating, controlling, and optimizing the model training process. This section provides a complete guide to its command-line arguments and demonstrates its use in common training scenarios.
Training Command-Line Arguments
The train command offers a wide range of arguments to configure the training run. These are summarized below.
Model & Data
- --preset (str): Selects a model architecture preset (small, smallx2, base).
- --rank (int): Overrides the rank r of the low-rank attention mechanism.
- --block (int): Sets the sequence length (block size) for each training step. Defaults to 576, which is also the starting point of the default growth plan.
- --source (str): Specifies the Hugging Face dataset to use for training. Defaults to cerebras/SlimPajama-627B.
- --x2 (bool): Doubles the number of layers defined by the selected --preset. If used with --warmstart_from, it doubles the layers of the loaded checkpoint's configuration.

Training Control
- --target_tokens (int): The total number of tokens the model should see before training concludes. If unset, it is automatically calculated using a Chinchilla-style rule (25x the total number of enabled parameters in both the core model and the AR head).
- --steps (int): A maximum number of training steps to run. Training stops when either this or target_tokens is reached.

Performance & Precision
- --amp (bool): Enables Automatic Mixed Precision (AMP) using bfloat16 or float16 for faster training.
- --fp8-only (bool): Attempts to use FP8 (float8_e4m3fn) for computation, offering the highest performance on compatible hardware.
- --fp8-fallback (bool): If FP8 is requested but unsupported, allows the script to fall back to bfloat16/float16 instead of raising an error.

Checkpointing & State
- --save_every_sec (int): The interval in seconds for saving checkpoints. Defaults to 24 hours (24 * 3600).
- --save_dir (str): The directory where checkpoints will be saved. Defaults to ckpts_joint.
- --resume (str): Path to a checkpoint file or directory to resume a training run from.
- --warmstart_from (str): Path to a checkpoint (e.g., final.pt) to initialize model weights from for a new run.
- --fresh (bool): Starts training from scratch, ignoring any existing checkpoints or warm-start paths.

Progressive Block Growth
- --auto_grow (bool): Enables the progressive block growth feature.
- --grow_plan (str): A comma-separated list of block sizes to progressively train on (e.g., "576,768,1024").
- --grow_every_steps (int): The number of steps to wait before attempting to increase the block size according to the plan.
Practical Training Scenarios
Here are some common examples of how to use the train command.
Scenario 1: Starting a Fresh Training Run
To start a new training run from scratch using the base preset with AMP enabled, use the --fresh flag to ensure no previous state is loaded.
python 5L.py train \
  --preset base \
  --amp \
  --save_dir ./ckpts_base_model \
  --fresh
This command initializes a new base model, saves checkpoints to ./ckpts_base_model, enables mixed-precision training, and explicitly prevents loading from any prior checkpoints.
Scenario 2: Resuming an Interrupted Job
If a training job is stopped, you can resume it seamlessly using the --resume flag, pointing it to the last saved checkpoint or its directory.
python 5L.py train \
  --preset base \
  --amp \
  --save_dir ./ckpts_base_model \
  --resume ./ckpts_base_model/step00123456.pt
The script will automatically load the model, optimizer, step count, and total seen tokens from the checkpoint and continue training.
Scenario 3: Warm-Starting from a Previous Model
To transfer learned knowledge to a new, larger model (e.g., by doubling the layers with --x2), you can warm-start from a completed checkpoint. This is a powerful technique for efficiently training larger architectures.
python 5L.py train \
  --preset base \
  --x2 \
  --amp \
  --save_dir ./ckpts_base_x2 \
  --warmstart_from ./ckpts_base_model/final.pt
This command initializes a new model with twice the layers of the base preset. It then loads all weights from final.pt where the tensor shapes match, providing a strong starting point for the new training run. The optimizer and step count are initialized from scratch.
Feature Deep Dive
Automatic Mixed Precision (AMP) and FP8
The script provides granular control over numerical precision to balance performance and accuracy:
- --amp: This is the standard way to enable AMP, which uses bfloat16 on compatible hardware or float16 otherwise. This significantly speeds up training with minimal impact on stability.
- --fp8-only: For cutting-edge performance, this flag attempts to use the float8_e4m3fn data type. This requires a recent version of PyTorch and compatible hardware (e.g., NVIDIA H100 series).
- --fp8-fallback: This flag acts as a safety net. Using --fp8-only on incompatible hardware without --fp8-fallback will raise a RuntimeError. The --fp8-fallback flag prevents this error by allowing the script to proceed with bfloat16 or float16 instead.
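The precision-selection order described above can be captured in a small decision function. This is a hedged sketch: the script's auto_amp_dtype inspects torch and the hardware directly, whereas here the capability checks are passed in as booleans:

```python
def pick_amp_dtype(cuda, fp8_only, fp8_fallback, fp8_supported, bf16_supported):
    """Sketch of the dtype-selection order: FP8 if requested and supported,
    otherwise bfloat16/float16, with --fp8-fallback gating the downgrade."""
    if not cuda:
        return None                      # AMP/FP8 require a CUDA device
    if fp8_only:
        if fp8_supported:
            return "float8_e4m3fn"
        if not fp8_fallback:
            raise RuntimeError("FP8 requested but unsupported")
    return "bfloat16" if bf16_supported else "float16"
```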
Progressive Block Growth
This feature is designed to improve training efficiency and stability by starting with smaller sequence lengths and gradually increasing them.
- --auto_grow: This master switch enables the feature.
- --grow_plan: Defines the sequence of block sizes to use. For example, --grow_plan "576,768,1024" instructs the script to start at a block size of 576, then attempt to grow to 768, and finally to 1024.
- --grow_every_steps: This sets the cadence for growth attempts. Every time this number of steps is completed, the script will try to move to the next block size in the plan.
Crucially, this feature is OOM-safe. If a growth attempt fails due to an out-of-memory error, the script automatically catches the error, reverts to the previous stable block size, and continues training without crashing. This makes it a robust tool for maximizing sequence length safely.
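The OOM-safe growth behavior can be approximated as follows. This is a simplified sketch, not the script's implementation: step_fn stands in for one training step, and a MemoryError stands in for a CUDA out-of-memory error:

```python
def grow_blocks(plan, step_fn, steps_per_stage):
    """Sketch of --auto_grow/--grow_plan: after each stage, probe the next
    block size in the plan; on an OOM failure, keep the last stable size."""
    block = plan[0]
    for nxt in plan[1:]:
        for _ in range(steps_per_stage):
            step_fn(block)           # train at the current stable size
        try:
            step_fn(nxt)             # probe the larger block size
            block = nxt              # growth succeeded
        except MemoryError:
            pass                     # OOM caught: revert, keep training
    return block
```

A --grow_plan string such as "576,768,1024" would simply be parsed into the list [576, 768, 1024] before being handed to such a loop.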
With a model successfully trained, the next step is to use it for generating text.
Generating Text (infer command)
Once a model has been trained and checkpointed, the infer command is used to generate text. This command loads a trained model and provides a powerful set of decoding parameters to control the properties of the generated output, from deterministic and factual to random and creative.
Inference Command-Line Arguments
The infer command is controlled by the following arguments, which allow you to specify the model, prompt, and decoding strategy.
Model and Prompt
- --mode (str): Required. Specifies the decoding mode. Must be set to ar.
- --ckpt (str): Required. Path to the model checkpoint (.pt) file to use for inference.
- --preset (str): The model preset used during training. Defaults to small.
- --prompt (str): Required. The initial text prompt to seed the generation.
- --max_new (int): The maximum number of new tokens to generate. Defaults to 120.
- --fp8-only (bool): Attempts to use FP8 autocasting during decoding for better performance.

Decoding Strategy
- --greedy (bool): Forces greedy decoding, where the most probable token is always chosen. This disables all other sampling parameters.
- --temperature (float): Controls randomness. Values > 1.0 make output more random; values < 1.0 make it more deterministic. Defaults to 1.0.
- --top_k (int): Filters the vocabulary to the k most likely next tokens. A value of 0 disables it.
- --top_p (float): Filters the vocabulary to the smallest set of tokens whose cumulative probability exceeds p. Unlike the fixed-size filtering of top_k, this adaptive method can select a variable number of tokens at each step. Defaults to 1.0 (disabled).
- --min_p (float): Filters out tokens with a probability lower than p. Defaults to 0.0 (disabled).

Repetition Control
- --repetition_penalty (float): Penalizes tokens that have already appeared. Values > 1.0 discourage repetition. Defaults to 1.0.
- --presence_penalty (float): Penalizes tokens based on their presence in the context, regardless of frequency. Defaults to 0.0.
- --frequency_penalty (float): Penalizes tokens based on how frequently they appear in the context. Defaults to 0.0.
- --penalty_last_n (int): Sets the number of recent tokens to consider when applying repetition penalties. Defaults to 64.
- --no_repeat_ngram_size (int): Prevents the generation of any n-gram that has already appeared. A value of 0 disables it.
Decoding and Sampling Strategies
You can combine the arguments above to implement various decoding strategies.
Greedy vs. Sampling
Using the --greedy flag results in deterministic output. The model will always select the token with the absolute highest probability at each step. This is useful for tasks requiring factual and predictable completions. When --greedy is not used, the model samples from a probability distribution, which can be shaped by the parameters below.
Controlling Randomness and Diversity
- --temperature: This is the most direct way to control creativity. A temperature of 0.7 will produce more focused and coherent text, while 1.2 will produce more surprising and diverse text.
- --top_k, --top_p, --min_p: These parameters filter the pool of candidate tokens before sampling.
- --top_k 50 restricts sampling to the 50 most likely tokens.
- --top_p 0.95 restricts sampling to the smallest set of tokens whose cumulative probability exceeds 95%. This adaptive method can select a different number of tokens at each step.
- You can combine these to create a more constrained sampling space.
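How these filters compose can be sketched in pure Python. Note this is an illustration of the logic described above, not the script's _filter_top_k_top_p_min_p, which operates on logit tensors; here candidates are plain (token, probability) pairs:

```python
def filter_top_k_top_p(probs, top_k=0, top_p=1.0, min_p=0.0):
    """Apply top-k, then top-p, then min-p filtering to a candidate pool.
    `probs` is a list of (token, probability) pairs; returns the kept pairs."""
    kept = sorted(probs, key=lambda tp: tp[1], reverse=True)
    if top_k > 0:
        kept = kept[:top_k]                 # fixed-size cutoff
    if top_p < 1.0:
        cum, nucleus = 0.0, []
        for tok, p in kept:                 # smallest set reaching top_p
            nucleus.append((tok, p))
            cum += p
            if cum >= top_p:
                break
        kept = nucleus
    if min_p > 0.0:
        kept = [(t, p) for t, p in kept if p >= min_p]
    return kept
```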
Managing Repetition and Penalties
- --repetition_penalty, --presence_penalty, --frequency_penalty: These three arguments dynamically lower the probability (logits) of tokens that have already been generated within the context window defined by --penalty_last_n. repetition_penalty applies a multiplicative penalty, while presence and frequency apply additive ones. They are effective tools for reducing loops and improving the coherence of longer generations.
- --no_repeat_ngram_size: This provides a hard constraint. For example, --no_repeat_ngram_size 3 completely forbids the model from generating any 3-gram (sequence of three tokens) that has already appeared in the output.
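The penalty arithmetic and the n-gram ban can both be illustrated with small pure-Python functions. These are hedged sketches of the behavior described above (the script applies the same math to logit tensors):

```python
from collections import Counter

def apply_penalties(logits, recent, rep=1.0, presence=0.0, frequency=0.0):
    """Sketch of the penalty math: `logits` maps token -> logit and `recent`
    holds the last --penalty_last_n generated tokens. The repetition penalty
    divides positive logits and multiplies negative ones; presence/frequency
    are additive."""
    counts = Counter(recent)
    out = dict(logits)
    for tok, n in counts.items():
        if tok not in out:
            continue
        x = out[tok]
        if rep != 1.0:
            x = x / rep if x > 0 else x * rep   # multiplicative penalty
        x -= presence                            # flat once-seen penalty
        x -= frequency * n                       # scales with occurrences
        out[tok] = x
    return out

def banned_next_tokens(seq, n):
    """Sketch of --no_repeat_ngram_size: tokens that would complete an n-gram
    already present in `seq` (the script sets their logits to -inf)."""
    if n <= 0 or len(seq) < n - 1:
        return set()
    prefix = tuple(seq[-(n - 1):]) if n > 1 else ()
    banned = set()
    for i in range(len(seq) - n + 1):
        if tuple(seq[i:i + n - 1]) == prefix:
            banned.add(seq[i + n - 1])
    return banned
```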
Practical Inference Example
The following command loads a checkpoint and generates text using a balanced sampling configuration designed for creative and coherent output.
python 5L.py infer \
  --mode ar \
  --ckpt ./ckpts_base_model/final.pt \
  --prompt "In a world where magic is powered by logic, the most powerful sorcerers are" \
  --max_new 256 \
  --temperature 0.75 \
  --top_k 50 \
  --top_p 0.95 \
  --repetition_penalty 1.1
This configuration encourages creative but coherent output. The temperature of 0.75 reduces randomness slightly, top_k and top_p together create a high-quality pool of candidate tokens, and the repetition_penalty discourages the model from getting stuck in loops.
This guide has equipped you with the knowledge to effectively configure, train, and operate the 5L.py script for your machine learning projects.
Technical Whitepaper: Architecture and Operation of a Low-Rank Autoregressive Model
1.0 Introduction
This whitepaper provides a definitive technical analysis of a parameter-efficient, autoregressive-only model architecture and its operationally resilient training and inference framework, 5L.py. Its purpose is to deconstruct the system's key components, from its core architectural decisions to its robust operational framework, for an audience of machine learning practitioners and researchers.
The system is distinguished by several key architectural and functional highlights that will be explored in detail. These include a novel LowRankMHA attention mechanism for computational efficiency, the use of Attention with Linear Biases (ALiBi) to dynamically encode positional information, a robust training framework featuring progressive block growth, and a flexible inference engine with multiple sampling strategies. This analysis will proceed by first examining the static model architecture before delving into the dynamic processes of training and inference.
2.0 Model Architecture
The strategic design of the model is centered on a decoder-only, autoregressive transformer architecture, optimized for efficient training and generation. This design philosophy is manifested in its core components: the Encoder, which forms the main transformer stack; the ARHead, which serves as the final output layer; and a specialized low-rank attention mechanism that is central to the model's efficiency.
2.2 Overall Structure
At a high level, the model is composed of two primary modules that work in sequence: the Encoder (the core transformer stack) and the ARHead (the final projection layer).
The function of the Encoder is to process input token sequences and produce a sequence of high-dimensional hidden state representations. It is composed of an nn.Embedding layer to convert token IDs into vectors, followed by a series of identical transformer Block modules, and concluding with a final nn.LayerNorm for output stabilization.
Following the Encoder, the ARHead performs the final, critical step of the autoregressive process. It consists of a single nn.Linear layer that takes the final hidden state from the Encoder and projects it to the full vocabulary dimension (VOCAB), producing the logits used to predict the next token in the sequence.
2.3 Core Component: Low-Rank Multi-Head Attention (LowRankMHA)
The central innovation within the transformer block is the LowRankMHA module, a parameter-efficient variant of standard multi-head attention. Its forward pass begins by projecting the input into query, key, and value tensors via the self.q, self.k, and self.v layers.
The key innovation is the low-rank projection step. Before attention scores are computed, the query and key tensors are multiplied by a shared, trainable parameter matrix self.U of shape (dk, r), where dk is the dimension per head and r is the rank. This projects the heads into a lower-dimensional space. The primary benefit of this projection is computational: the expensive query-key matrix multiplication, typically of complexity O(N^2 * dk), becomes O(N^2 * r). This reduction in the inner dimension is the main source of efficiency gains, especially when the rank r is significantly smaller than the head dimension dk.
After the attention scores are computed and applied to the value vectors, the concatenated low-rank outputs from all heads are passed through a final projection layer (self.proj), which maps the combined results back to the model's primary hidden dimension d.
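The low-rank trick can be shown with a minimal numeric sketch. This is not the module's code (which works per-head on batched tensors); it only demonstrates that projecting Q and K through a shared (dk x r) matrix U shrinks the inner dimension of the expensive N x N product from dk to r:

```python
def matmul(A, B):
    # naive dense product, enough to illustrate the shapes involved
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]

def lowrank_scores(Q, K, U):
    """Project queries and keys through a shared (dk x r) matrix U, then
    compute the (N x N) score matrix with inner dimension r instead of dk."""
    Qr = matmul(Q, U)                                   # (N, r)
    Kr = matmul(K, U)                                   # (N, r)
    Kr_T = [list(col) for col in zip(*Kr)]              # (r, N)
    return matmul(Qr, Kr_T)                             # (N, N) scores
```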
2.4 Relative Positional Encoding: Attention with Linear Biases (ALiBi)
The model eschews traditional positional embeddings in favor of Attention with Linear Biases (ALiBi), a relative positioning scheme implemented in the alibi_bias function. A bias matrix is dynamically constructed based on the distance between tokens (j - i) and a set of head-specific slopes. These slopes are not arbitrary but are derived from a geometric series (specifically 2**(-2**-(log2(n)-3))), a core design element of ALiBi that ensures more distant tokens receive exponentially smaller biases. This bias is added directly to the attention scores before the softmax operation, allowing the model to weigh token interactions based on their proximity.
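The slope schedule and bias construction can be sketched directly from the formula quoted above (this assumes a power-of-two head count, as in the ALiBi scheme; the real alibi_bias builds a tensor, and causal masking of future positions is handled elsewhere):

```python
import math

def alibi_slopes(n_heads):
    """Geometric slope series with ratio 2**(-2**-(log2(n)-3)): for 8 heads
    this yields 1/2, 1/4, ..., 1/256."""
    ratio = 2 ** (-2 ** -(math.log2(n_heads) - 3))
    return [ratio ** (i + 1) for i in range(n_heads)]

def alibi_bias(slope, n):
    """Per-head bias matrix: slope * (j - i) for past positions j <= i, added
    to the attention scores before softmax. Distant tokens get more-negative
    biases; future positions are left at zero (masked separately)."""
    return [[slope * (j - i) if j <= i else 0.0 for j in range(n)]
            for i in range(n)]
```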
2.5 Transformer Block (Block)
The Block module encapsulates a standard transformer layer structure, arranged in a pre-normalization configuration for improved training stability. Each block consists of a pre-normalization layer (self.ln1) followed by the LowRankMHA module, and a second pre-normalization layer (self.ln2) followed by a feed-forward network (self.ff). The feed-forward network is a two-layer MLP that expands the hidden dimension d by a factor of 4, with a ReLU non-linearity between the layers. Residual connections are applied after both the multi-head attention and feed-forward sub-layers. Having detailed the model's static structure, we now turn to the dynamic framework responsible for training it.
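The pre-norm residual wiring described above reduces to two lines. In this sketch each argument is a callable standing in for the corresponding sub-module (ln1, attention, ln2, feed-forward); it illustrates the dataflow, not the module's actual implementation:

```python
def block_forward(x, ln1, attn, ln2, ff):
    """Pre-normalization transformer block: normalize, transform, then add
    the residual, for both the attention and feed-forward sub-layers."""
    x = x + attn(ln1(x))   # residual around pre-normed attention
    x = x + ff(ln2(x))     # residual around pre-normed feed-forward
    return x
```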
3.0 Training Framework
The training script is strategically engineered for robustness, efficiency, and scalability. Its design incorporates features such as automatic mixed precision for faster computation, dynamic block sizing for optimal hardware utilization, and fault-tolerant checkpointing to ensure the integrity of long-running training jobs.
3.2 Data Ingestion and Processing
The data pipeline is managed by the token_stream function, which leverages the datasets library to load a streaming dataset, defaulting to cerebras/SlimPajama-627B. This approach avoids downloading the entire dataset locally. The stream is shuffled with a buffer to ensure randomized training examples, and the function yields individual tokens. To ensure consistent sequence termination, an End-of-Sentence (EOS) token is automatically appended to each example if one is not already present.
3.3 Optimization and Mixed-Precision Training
The training process is driven by the torch.optim.AdamW optimizer, configured with separate learning rates for the core model (LR_CORE = 5e-5) and the autoregressive head (LR_HEAD = 2e-4). The strategic rationale for this is that the ARHead, a simple single-layer projection, can and should learn much faster than the deep, complex Encoder core. This allows for rapid adaptation of the output layer without destabilizing the foundational representations learned by the main model body.
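The two-rate setup amounts to handing AdamW two parameter groups. A minimal sketch, with the group construction factored out so it can be shown without instantiating a model (the real script passes these groups to torch.optim.AdamW):

```python
def make_param_groups(core_params, head_params, lr_core=5e-5, lr_head=2e-4):
    """Build optimizer parameter groups with the document's learning rates:
    a conservative rate for the deep encoder core and a 4x-faster rate for
    the single-layer AR head."""
    return [
        {"params": list(core_params), "lr": lr_core},
        {"params": list(head_params), "lr": lr_head},
    ]
```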
To accelerate training, the amp helper provides robust Automatic Mixed Precision (AMP) capabilities. The auto_amp_dtype function automatically selects the optimal data type based on hardware support and user flags, in order of preference: torch.float8_e4m3fn (if enabled via --fp8-only), torch.bfloat16, or torch.float16. To prevent exploding gradients, torch.nn.utils.clip_grad_norm_ is applied to clip the gradient norm at a maximum value of 1.0.
3.4 Dynamic Training Strategies
The framework includes two key dynamic strategies to maximize training efficiency and resilience:
- Progressive Block Growth: The --auto_grow feature enables curriculum-based training. The process starts with a smaller block size and periodically increases it according to a predefined plan (--grow_plan) every grow_every_steps.
- Out-of-Memory (OOM) Resilience: The main training loop is wrapped in a try...except block to catch CUDA "out of memory" errors. Upon catching such an error, the script automatically reduces the current BLOCK size by half (to a minimum of 128) and retries the step, preventing crashes from memory spikes.
3.5 Persistence and State Management
A robust checkpointing system ensures training progress is not lost. Checkpoints are saved on a time-based cadence determined by --save_every_sec (defaulting to 24 hours). As defined in save_ckpt, each checkpoint is a comprehensive snapshot containing model state dictionaries (core, ar), optimizer and scaler states, model configuration (cfg), training metadata (step, seen_tok, wall_time), and the Python (py_state) and PyTorch (torch_state) random number generator states for perfectly reproducible resumption. The script seamlessly handles both resuming training (--resume) and warm-starting from existing weights (--warmstart_from).
3.6 Target Token Calculation
Following Chinchilla-style methodology, if a --target_tokens value is not provided, the script calculates a compute-optimal target by multiplying the total number of enabled model parameters by 25. This calculation, per the _count_enabled_params function, deliberately includes all trainable parameters from both the core model and the ar_h head, demonstrating a rigorous application of scaling laws to the entire model system. With a comprehensive understanding of how the model is trained, the next section will explore how it is used for text generation.
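The default target described above is a single multiplication over the combined parameter count:

```python
def target_tokens(core_params, head_params, multiplier=25):
    """Chinchilla-style default: 25 training tokens per enabled parameter,
    counting both the core model and the AR head (as _count_enabled_params
    does in the script)."""
    return multiplier * (core_params + head_params)
```

For example, a 100M-parameter core with a 1M-parameter head yields a target of about 2.5B tokens.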
4.0 Inference and Text Generation
The script's inference capabilities are designed for flexibility, providing a rich set of command-line options that control the decoding process, from greedy selection to sophisticated penalized sampling.
4.2 Autoregressive Decoding Process
The ar_decode function orchestrates text generation. The process begins with an initialization phase where the input prompt is tokenized and passed through the core model with use_cache=True. This generates the initial hidden states and Key-Value caches (kvs). The subsequent iterative generation loop is highly efficient: for each new token, the model processes only the last token's hidden state and uses the updated kvs, avoiding redundant computation over the full sequence.
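The cached decode loop can be sketched abstractly. This is a hedged stand-in for ar_decode: model_step(tokens, cache) represents one forward pass returning (logits, updated cache), and sample(logits) picks the next token id; the point is that after the prompt pass, each iteration feeds only the newest token plus the cache:

```python
def ar_decode_sketch(prompt_ids, model_step, sample, max_new):
    """Prime the KV cache on the full prompt, then generate token-by-token,
    passing only the last token and the cache at each step."""
    logits, cache = model_step(prompt_ids, None)   # prompt pass primes cache
    out = list(prompt_ids)
    for _ in range(max_new):
        nxt = sample(logits)
        out.append(nxt)
        logits, cache = model_step([nxt], cache)   # single-token step
    return out
```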
4.3 Sampling and Filtering Strategies
The _filter_top_k_top_p_min_p function implements several core sampling strategies, which can be controlled via command-line arguments.
- Greedy Decoding (--greedy): Overrides sampling methods; deterministically selects the token with the highest probability.
- Temperature (--temperature): Controls randomness. Values < 1.0 sharpen the distribution; values > 1.0 flatten it.
- Top-K Sampling (--top_k): Restricts sampling to the k most likely tokens. 0 disables this filter.
- Top-P Sampling (--top_p): Restricts sampling to the smallest set of tokens whose cumulative probability exceeds p.
- Min-P Sampling (--min_p): Filters out tokens with a probability below the min_p threshold.
4.4 Generation Penalty Mechanisms
To mitigate common issues like repetition, the script provides several penalty mechanisms that directly modify the logits before sampling.
The _apply_rep_presence_frequency function modifies logits based on tokens that have appeared in a recent history window, controlled by --penalty_last_n. It applies three distinct penalties:
- Repetition Penalty (--repetition_penalty): Modifies the logits of previously seen tokens. It divides positive logits by the penalty factor but multiplies negative logits, preventing unlikely tokens from becoming more likely.
- Presence Penalty (--presence_penalty): Subtracts a fixed penalty from the logits of any token that has appeared at least once.
- Frequency Penalty (--frequency_penalty): Subtracts a penalty proportional to how many times a token has already appeared.
Additionally, the _apply_no_repeat_ngram function, controlled by --no_repeat_ngram_size, prevents the model from generating any n-gram of the specified size that has already appeared in the output by setting the offending logits to negative infinity. The next section provides a practical guide to the command-line interface.
5.0 Operational Guide: Command-Line Interface
This section serves as a practical guide to using the script via its command-line interface (CLI), which is organized into two main subcommands: train and infer.
5.2 Prerequisites and General Configuration
The script utilizes the Qwen/Qwen3-235B-A22B-Thinking-2507 tokenizer by default, as specified by the TOKENIZER_ID global variable. For convenience, several preset model configurations (small, smallx2, base) are available via the --preset argument.
5.3 Training (train command)
The train subcommand initiates and manages the model training process. The table below outlines its most critical arguments.
- --preset: Selects a predefined model architecture (small, smallx2, base).
- --rank: Overrides the preset rank (r) for the LowRankMHA module.
- --block: Sets the sequence length (block size) for each training step.
- --x2: Doubles the number of layers specified by the preset or a warm-started model, enabling rapid scaling experiments.
- --source: Specifies the streaming dataset to use for training.
- --save_dir: The directory where checkpoints will be saved.
- --resume: Path to a checkpoint to resume training from.
- --warmstart_from: Path to a checkpoint for warm-starting (initializing weights).
- --fresh: Starts training from scratch, ignoring any existing checkpoints.
- --amp: Enables standard automatic mixed-precision training (BF16/FP16).
- --fp8-only: Attempts to use float8_e4m3fn for training.
- --auto_grow: Enables the progressive block growth training strategy.
5.4 Inference (infer command)
The infer subcommand is used to generate text from a trained model checkpoint. Its key arguments are detailed below.
- --ckpt: Required. Path to the model checkpoint file to load for inference.
- --prompt: Required. The initial text prompt to seed the generation.
- --max_new: The maximum number of new tokens to generate.
- --temperature: Sets the sampling temperature (default: 1.0).
- --top_k: Enables Top-K sampling with the specified value k.
- --repetition_penalty: Applies a penalty to repeated tokens to discourage redundancy.
- --penalty_last_n: Defines the token history window for penalty calculations (default: 64).
This operational guide provides the necessary information to effectively utilize the script.
6.0 Conclusion
This document has presented a detailed technical overview of a low-rank autoregressive model and its associated operational script, defined by design choices aimed at achieving a balance between performance, efficiency, and robustness.
The architectural novelty of the LowRankMHA mechanism and the computational efficiency of ALiBi provide a parameter-efficient foundation. This is supported by a sophisticated training framework that emphasizes dynamic resource management through progressive block growth and OOM resilience, as well as state persistence via reliable checkpointing. Paired with a flexible and feature-rich inference engine, the script serves as a robust and flexible blueprint for the end-to-end development and deployment of high-performance autoregressive language models.