# Zenyx-v2-base (Nano-Titan Architecture)
Zenyx-v2-base is a highly optimized, 85-million-parameter language model built on the custom Nano-Titan architecture. Engineered entirely in JAX/Flax for pure BF16 execution on TPU v5e-8 hardware, the model is designed to push the boundaries of parameter efficiency. The pretraining corpus is weighted towards reasoning, drawing on a curated mix of mathematics, programming languages, and high-quality educational text, with a target training budget of 200 billion tokens.
## Architectural Specifications
The Nano-Titan architecture employs parameter sharing and advanced attention mechanisms to achieve an effective depth far exceeding its raw parameter count.
| Component | Specification |
|---|---|
| Total Parameters | ~85M (Unique weights) |
| Effective Depth | 32 Layers (8 unique blocks × 4 recurrences) |
| Embedding Dimension ($D_{model}$) | 576 |
| Attention Mechanism | Multi-Head Latent Attention (MLA) |
| Heads | 9 Query Heads, 3 KV Heads |
| MLP Architecture | ConvSwiGLU (Hidden Dim: 1536) |
| Context Window | 8192 Tokens (YaRN-scaled RoPE) |
| Vocabulary Size | 32,768 |
## Core Innovations
### Multi-Token Prediction (MTP)

The model leverages a 3-head Multi-Token Prediction objective during pretraining. By predicting multiple future tokens simultaneously with decaying loss weights (1.0, 0.3, 0.1), the network develops stronger long-term contextual representations and improves sample efficiency without increasing inference latency.
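The weighted objective can be sketched in framework-agnostic NumPy. This is a minimal illustration of the loss combination, not the training code; the shapes and the exact target-offset convention (head *k* predicts the token at offset *k*+1) are assumptions:

```python
import numpy as np

# Decaying per-head loss weights as described above.
MTP_WEIGHTS = (1.0, 0.3, 0.1)

def cross_entropy(logits, targets):
    """Mean token-level cross-entropy from raw logits."""
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -np.take_along_axis(log_probs, targets[..., None], axis=-1).mean()

def mtp_loss(head_logits, tokens):
    """head_logits: list of [B, T, V] arrays, one per MTP head.
    tokens: [B, T + n_heads] ground-truth ids (extra columns supply the
    lookahead targets for the deeper heads)."""
    total = 0.0
    for k, (logits, w) in enumerate(zip(head_logits, MTP_WEIGHTS)):
        targets = tokens[:, k + 1 : k + 1 + logits.shape[1]]
        total += w * cross_entropy(logits, targets)
    return total
```

At inference time only head 0 is used, which is why the extra heads add no latency.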
### Multi-Head Latent Attention (MLA)

Memory overhead during training and inference is drastically reduced via MLA. The key-value cache is compressed into a 128-dimensional latent space, while queries are projected through a 384-dimensional latent bottleneck. This structural optimization allows for extended context windows within constrained hardware environments.
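The shapes involved can be illustrated with a NumPy sketch using the dimensions quoted above (d_model 576, KV latent 128, query latent 384, 9 query heads and 3 KV heads of dim 64). The weight values and the exact projection layout are stand-in assumptions, not the trained parameters:

```python
import numpy as np

rng = np.random.default_rng(0)
D, KV_LATENT, Q_LATENT, H_Q, H_KV, HD = 576, 128, 384, 9, 3, 64
T = 1024  # short sequence for illustration; the real context is 8192

W_dkv = rng.normal(size=(D, KV_LATENT)) * 0.02        # hidden -> KV latent
W_uk  = rng.normal(size=(KV_LATENT, H_KV * HD)) * 0.02  # latent -> keys
W_uv  = rng.normal(size=(KV_LATENT, H_KV * HD)) * 0.02  # latent -> values
W_dq  = rng.normal(size=(D, Q_LATENT)) * 0.02         # hidden -> query latent
W_uq  = rng.normal(size=(Q_LATENT, H_Q * HD)) * 0.02  # query latent -> queries

x = rng.normal(size=(1, T, D))
c_kv = x @ W_dkv                          # [1, T, 128] -- all that is cached
k = (c_kv @ W_uk).reshape(1, T, H_KV, HD)
v = (c_kv @ W_uv).reshape(1, T, H_KV, HD)
q = ((x @ W_dq) @ W_uq).reshape(1, T, H_Q, HD)

# Cache footprint vs. a plain per-head KV cache:
plain_cache = 2 * H_KV * HD   # 384 floats per token
mla_cache = KV_LATENT         # 128 floats per token
```

Under these dimensions the latent cache is 3x smaller per token than caching the expanded keys and values directly.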
### ConvSwiGLU Feedforward Networks

Standard feedforward blocks are replaced with Convolutional SwiGLU networks. A 1D convolution (kernel size 3) is applied to the gating mechanism prior to the SiLU activation, imparting a localized inductive bias that enhances the processing of structured data formats like code and mathematics.
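A minimal NumPy sketch of the idea, assuming a per-channel ("depthwise") same-padded convolution on the gate branch; the actual TitanBlock convolution may mix channels or be causal, so treat the details as illustrative:

```python
import numpy as np

def silu(x):
    return x / (1.0 + np.exp(-x))

def conv1d_same(x, kernel):
    """Per-channel 1D convolution with same-length output.
    x: [T, C], kernel: [3, C]. Zero padding of 1 on each side."""
    T, _ = x.shape
    padded = np.pad(x, ((1, 1), (0, 0)))
    out = np.zeros_like(x)
    for k in range(3):            # kernel size 3, as described above
        out += padded[k : k + T] * kernel[k]
    return out

def conv_swiglu(x, W_gate, W_up, W_down, conv_kernel):
    """SwiGLU MLP with a kernel-3 convolution on the gate before SiLU."""
    gate = silu(conv1d_same(x @ W_gate, conv_kernel))
    return (gate * (x @ W_up)) @ W_down
```

The convolution lets each gate value see its two neighbouring positions, which is the "localized inductive bias" referred to above.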
### Recurrent Depth via Block Sharing
To maximize parameter utility, the network consists of 8 distinct TitanBlock modules that are sequentially recurred 4 times. This yields an effective representational depth of 32 layers while maintaining an 85M parameter memory footprint.
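The sharing scheme amounts to a simple nested loop. In this sketch the block internals are placeholders (a residual update with random weights), but the application order, 8 distinct blocks recurred 4 times, matches the description above:

```python
import numpy as np

N_UNIQUE, N_RECUR = 8, 4  # 8 unique TitanBlocks, recurred 4 times

def make_block(seed, dim=16):
    rng = np.random.default_rng(seed)
    W = rng.normal(size=(dim, dim)) * 0.02
    def block(h):
        return h + np.tanh(h @ W)  # placeholder residual update
    return block

blocks = [make_block(i) for i in range(N_UNIQUE)]

def forward(h):
    applications = []
    for _ in range(N_RECUR):                 # outer recurrence
        for i, block in enumerate(blocks):   # the 8 shared blocks
            h = block(h)
            applications.append(i)
    return h, applications

h, apps = forward(np.zeros((1, 16)))
assert len(apps) == 32  # effective depth: 8 blocks x 4 recurrences
```

Only the 8 blocks' weights are stored, so the parameter count stays at ~85M while the forward pass traverses 32 layer applications.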
### YaRN-Scaled Rotary Position Embeddings

Extrapolation to an 8192-token context is stabilized using YaRN (Yet another RoPE extensioN). The scaling factor is set to 32.0 with dynamic wavelength interpolation, ensuring uniform attention distribution across both local and extended sequence lengths.
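A simplified sketch of YaRN-style frequency interpolation follows. The scale factor 32.0 comes from the description above; the ramp bounds and original context length are illustrative assumptions, and the real method also applies an attention temperature not shown here. High-frequency (short-wavelength) dimensions keep their original rotary frequencies, low-frequency dimensions are divided by the scale, and a linear ramp blends the two regimes:

```python
import numpy as np

def yarn_frequencies(head_dim=64, base=10000.0, scale=32.0,
                     orig_ctx=256, low=1.0, high=32.0):
    """Return per-dimension rotary frequencies with YaRN-style blending.
    orig_ctx, low, high are assumed hyperparameters for illustration."""
    inv_freq = 1.0 / base ** (np.arange(0, head_dim, 2) / head_dim)
    wavelen = 2 * np.pi / inv_freq
    # ramp = 0 where the wavelength is much shorter than the original
    # context (keep the frequency), 1 where it is much longer (interpolate).
    t = (orig_ctx / wavelen - low) / (high - low)
    ramp = 1.0 - np.clip(t, 0.0, 1.0)
    return inv_freq * (1.0 - ramp) + (inv_freq / scale) * ramp
```

The blend is what keeps local attention patterns intact while stretching only the long-range frequencies to cover 8192 tokens.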
## Training Data & Optimization
The dataset pipeline is fully streaming and dynamic, reading sharded Parquet files directly from the Hugging Face Hub to prevent local disk bottlenecks.
| Domain | Mixture Weight | Source Datasets | Filtration Criteria |
|---|---|---|---|
| Mathematics | 45% | finemath-4plus, infiwebmath | Statically routed across 3+ and 4+ streams |
| Code | 35% | starcoderdata | 24 strictly filtered languages, min length 20 chars |
| English / Edu | 20% | fineweb-edu | Educational score $\ge$ 3.0 |
Training utilizes a Warmup-Stable-Decay (WSD) learning-rate schedule, peaking at 3e-4 before executing a cosine decay to 3e-5 over the final 18,000 steps. Cross-entropy is computed in chunks alongside global gradient accumulation to strictly cap per-chip HBM usage at ~12.2 GB.
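The WSD schedule can be sketched as a plain function of the step count. The peak (3e-4), floor (3e-5), and 18,000-step decay window come from the description above; the warmup length and total step count are assumed for illustration:

```python
import math

PEAK_LR, FINAL_LR = 3e-4, 3e-5
WARMUP, TOTAL, DECAY = 2_000, 100_000, 18_000  # assumed step counts

def wsd_lr(step):
    """Warmup-Stable-Decay learning rate at a given training step."""
    if step < WARMUP:                        # linear warmup to the peak
        return PEAK_LR * step / WARMUP
    if step < TOTAL - DECAY:                 # stable plateau at the peak
        return PEAK_LR
    # cosine decay from PEAK_LR to FINAL_LR over the final DECAY steps
    progress = (step - (TOTAL - DECAY)) / DECAY
    return FINAL_LR + 0.5 * (PEAK_LR - FINAL_LR) * (1 + math.cos(math.pi * progress))
```

In practice such a schedule would typically be built with Optax schedule combinators and fed to the optimizer; this scalar form just makes the three phases explicit.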
## Evaluation Metrics (Checkpoint: Step 96,000)
The model is currently evaluated against a dedicated high-quality educational validation split encompassing 128 micro-batches.
| Metric | Value |
|---|---|
| Best Validation Loss | 3.7206 |
| Recent Training Loss (Avg) | ~2.10 |
| Throughput | ~200,000 tokens/sec |
| Tokens Processed | ~200 Billion |
## Implementation & Usage
Until the weights are explicitly converted to PyTorch safetensors and mapped to AutoModelForCausalLM, the model must be instantiated using the native JAX/Flax module defined in the training environment.
```python
import jax
import jax.numpy as jnp
from flax import serialization
from transformers import PreTrainedTokenizerFast

from zenyx_model import ZenyxV2  # Import your model class here

# 1. Initialize Tokenizer
tokenizer = PreTrainedTokenizerFast.from_pretrained("Arko007/zenyx-v2-tokenizer")

# 2. Instantiate Architecture
model = ZenyxV2(
    vocab_size=32768,
    d_model=576,
    n_heads=9,
    n_kv_heads=3,
    head_dim=64,
    hidden_dim=1536,
    n_unique_blocks=8,
    n_recurrences=4,
    max_seq_len=8192,
    kv_latent=128,
    q_latent=384,
    mtp_heads=3,
    dropout_rate=0.0,
)

# 3. Load Checkpoint Weights
# Note: Ensure params_step96000.msgpack is downloaded from the Hub
with open("params_step96000.msgpack", "rb") as f:
    params_bytes = f.read()

dummy_input = jnp.ones((1, 8192), dtype=jnp.int32)
rng = jax.random.PRNGKey(0)
variables = model.init(rng, input_ids=dummy_input, train=False)
loaded_params = serialization.from_bytes(variables["params"], params_bytes)

# 4. Forward Pass (Inference)
input_text = "The derivative of f(x) = x^2 is"
input_ids = tokenizer(input_text, return_tensors="jax")["input_ids"]

# The model returns a list of MTP logits; index 0 is the primary
# next-token prediction head.
logits = model.apply({"params": loaded_params}, input_ids=input_ids, train=False)[0]
next_token = jnp.argmax(logits[0, -1, :])
print(tokenizer.decode([int(next_token)]))
```
## Limitations and Biases
Zenyx-v2-base functions strictly as a foundational pre-trained model. It has not been subjected to instruction fine-tuning, Reinforcement Learning from Human Feedback (RLHF), or Direct Preference Optimization (DPO). As a direct result of operating exclusively on next-token and multi-token prediction objectives, the model does not possess conversational alignment. Outputs may exhibit hallucinations, lack structural coherence in zero-shot dialogue scenarios, or reflect inherent biases present within the constituent datasets (FineMath, StarCoderData, and FineWeb-Edu). Downstream utilization requires domain-specific fine-tuning and the implementation of robust safety guardrails prior to production deployment.
## Hardware and Environmental Setup
The pre-training lifecycle was executed on Google TPU v5e-8 topology utilizing pure bfloat16 (BF16) arithmetic. The distributed execution framework leverages jax.pmap for cross-device data parallelism. The software stack mandates JAX $\ge$ 0.4.16, optimized via libtpu, with gradient transformations managed through Optax. The asynchronous data pipeline is heavily optimized for high-throughput streaming directly from Hugging Face Parquet shards to circumvent local disk I/O bottlenecks.
## License
The Zenyx-v2-base model weights, architecture code, and associated tokenizer are distributed under the Apache License 2.0. You are free to use, modify, distribute, and utilize this model for commercial applications, provided you adhere to the attribution and liability limitations specified within the license framework.
## Citation
If you utilize the Zenyx-v2-base model, the Nano-Titan architecture, or the associated multi-token prediction training methodologies in your research or production environments, please cite the repository using the following BibTeX entry:
```bibtex
@misc{zenyx2026,
  author       = {Anamitra},
  title        = {Zenyx-v2-base: A Parameter-Efficient Nano-Titan Language Model},
  year         = {2026},
  publisher    = {Hugging Face},
  howpublished = {\url{https://huggingface.co/Arko007/zenyx-v2-base}},
  note         = {Pre-trained on TPU v5e-8 using JAX/Flax}
}
```