Daily Papers

by AK and the research community

Apr 14

FlashAttention-4: Algorithm and Kernel Pipelining Co-Design for Asymmetric Hardware Scaling

Attention, as a core layer of the ubiquitous Transformer architecture, is the bottleneck for large language models and long-context applications. While FlashAttention-3 optimized attention for Hopper GPUs through asynchronous execution and warp specialization, it primarily targets the H100 architecture. The AI industry has rapidly transitioned to deploying Blackwell-based systems such as the B200 and GB200, which exhibit fundamentally different performance characteristics due to asymmetric hardware scaling: tensor core throughput doubles while other functional units (shared memory bandwidth, exponential units) scale more slowly or remain unchanged. We develop several techniques to address these shifting bottlenecks on Blackwell GPUs: (1) redesigned pipelines that exploit fully asynchronous MMA operations and larger tile sizes, (2) software-emulated exponential and conditional softmax rescaling that reduces non-matmul operations, and (3) leveraging tensor memory and the 2-CTA MMA mode to reduce shared memory traffic and atomic adds in the backward pass. We demonstrate that our method, FlashAttention-4, achieves up to 1.3× speedup over cuDNN 9.13 and 2.7× over Triton on B200 GPUs with BF16, reaching up to 1613 TFLOPs/s (71% utilization). Beyond algorithmic innovations, we implement FlashAttention-4 entirely in CuTe-DSL embedded in Python, achieving 20-30× faster compile times compared to traditional C++ template-based approaches while maintaining full expressivity.

  • 6 authors
·
Mar 5
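
The softmax rescaling mentioned in the FlashAttention-4 abstract builds on the online-softmax recurrence used by tiled attention kernels. The pure-Python sketch below illustrates that recurrence for a single query row and rescales the running accumulator only when a tile raises the running maximum. It is a simplified stand-in for the idea of conditional rescaling, not the paper's kernel or its exact skip criterion; the function names and the tile size are hypothetical.

```python
import math

def naive_attention_row(q, keys, values):
    """Reference: softmax over q . k_j, then a weighted sum of value vectors."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def tiled_attention_row(q, keys, values, tile=2):
    """Online softmax over key/value tiles for one query row.

    Returns (output, number_of_rescales). The accumulator is rescaled only
    when a tile raises the running maximum (an illustrative condition).
    """
    dim = len(values[0])
    running_max = float("-inf")
    denom = 0.0
    acc = [0.0] * dim
    rescales = 0
    for start in range(0, len(keys), tile):
        k_tile = keys[start:start + tile]
        v_tile = values[start:start + tile]
        scores = [sum(a * b for a, b in zip(q, k)) for k in k_tile]
        tile_max = max(scores)
        if tile_max > running_max:
            # Re-express earlier partial sums relative to the new maximum.
            scale = math.exp(running_max - tile_max)  # 0.0 on the first tile
            denom *= scale
            acc = [a * scale for a in acc]
            running_max = tile_max
            rescales += 1
        for s, v in zip(scores, v_tile):
            w = math.exp(s - running_max)
            denom += w
            for d in range(dim):
                acc[d] += w * v[d]
    return [a / denom for a in acc], rescales
```

Skipping the rescale when the maximum is unchanged saves the multiply over the accumulator for that tile, which is the kind of non-matmul work the abstract says becomes a bottleneck when tensor core throughput scales faster than other units.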

TPLA: Tensor Parallel Latent Attention for Efficient Disaggregated Prefill & Decode Inference

Multi-Head Latent Attention (MLA), introduced in DeepSeek-V2, compresses key-value states into a low-rank latent vector, caching only this vector to reduce memory. In tensor parallelism (TP), however, attention heads are computed across multiple devices, and each device must load the full cache, eroding the advantage of MLA over Grouped Query Attention (GQA). We propose Tensor-Parallel Latent Attention (TPLA): a scheme that partitions both the latent representation and each head's input dimension across devices, performs attention independently per shard, and then combines results with an all-reduce. TPLA preserves the benefits of a compressed KV cache while unlocking TP efficiency. Unlike Grouped Latent Attention (GLA), every head in TPLA still leverages the full latent representation, maintaining stronger representational capacity. TPLA is drop-in compatible with models pre-trained using MLA: it supports MLA-style prefilling and enables efficient tensor-parallel decoding without retraining. Applying simple orthogonal transforms -- e.g., the Hadamard transform or PCA -- before TP slicing further mitigates cross-shard interference, yielding minimal accuracy degradation. By reducing the per-device KV cache for DeepSeek-V3 and Kimi-K2, we achieve 1.79x and 1.93x speedups, respectively, at a 32K-token context length while maintaining performance on commonsense and LongBench benchmarks. TPLA can be implemented with FlashAttention-3, enabling practical end-to-end acceleration.

  • 7 authors
·
Aug 21, 2025

Tensor Decomposition Networks for Fast Machine Learning Interatomic Potential Computations

SO(3)-equivariant networks are the dominant models for machine learning interatomic potentials (MLIPs). The key operation of such networks is the Clebsch-Gordan (CG) tensor product, which is computationally expensive. To accelerate the computation, we develop tensor decomposition networks (TDNs) as a class of approximately equivariant networks in which CG tensor products are replaced by low-rank tensor decompositions, such as the CANDECOMP/PARAFAC (CP) decomposition. With the CP decomposition, we prove (i) a uniform bound on the induced error of SO(3)-equivariance, and (ii) the universality of approximating any equivariant bilinear map. To further reduce the number of parameters, we propose path-weight sharing that ties all multiplicity-space weights across the O(L^3) CG paths into a single shared parameter set without compromising equivariance, where L is the maximum angular degree. The resulting layer acts as a plug-and-play replacement for tensor products in existing networks, and the computational complexity of tensor products is reduced from O(L^6) to O(L^4). We evaluate TDNs on PubChemQCR, a newly curated molecular relaxation dataset containing 105 million DFT-calculated snapshots. We also use existing datasets, including OC20 and OC22. Results show that TDNs achieve competitive performance with dramatic speedup in computations. Our code is publicly available as part of the AIRS library (https://github.com/divelab/AIRS/tree/main/OpenMol/TDN).

  • 9 authors
·
Jul 1, 2025

The Price of Freedom: Exploring Expressivity and Runtime Tradeoffs in Equivariant Tensor Products

E(3)-equivariant neural networks have demonstrated success across a wide range of 3D modelling tasks. A fundamental operation in these networks is the tensor product, which interacts two geometric features in an equivariant manner to create new features. Due to the high computational complexity of the tensor product, significant effort has been invested to optimize the runtime of this operation. For example, Luo et al. (2024) recently proposed the Gaunt tensor product (GTP) which promises a significant speedup. In this work, we provide a careful, systematic analysis of a number of tensor product operations. In particular, we emphasize that different tensor products are not performing the same operation. The reported speedups typically come at the cost of expressivity. We introduce measures of expressivity and interactability to characterize these differences. In addition, we realized the original implementation of GTP can be greatly simplified by directly using a spherical grid at no cost in asymptotic runtime. This spherical grid approach is faster on our benchmarks and in actual training of the MACE interatomic potential by 30%. Finally, we provide the first systematic microbenchmarks of the various tensor product operations. We find that the theoretical runtime guarantees can differ wildly from empirical performance, demonstrating the need for careful application-specific benchmarking. Code is available at https://github.com/atomicarchitects/PriceofFreedom.

  • 4 authors
·
Jun 16, 2025

Enabling Efficient Equivariant Operations in the Fourier Basis via Gaunt Tensor Products

Developing equivariant neural networks for the E(3) group plays an important role in modeling 3D data across real-world applications. Enforcing this equivariance primarily involves the tensor products of irreducible representations (irreps). However, the computational complexity of such operations increases significantly as higher-order tensors are used. In this work, we propose a systematic approach to substantially accelerate the computation of the tensor products of irreps. We mathematically connect the commonly used Clebsch-Gordan coefficients to the Gaunt coefficients, which are integrals of products of three spherical harmonics. Through Gaunt coefficients, the tensor product of irreps becomes equivalent to the multiplication between spherical functions represented by spherical harmonics. This perspective further allows us to change the basis for the equivariant operations from spherical harmonics to a 2D Fourier basis. Consequently, the multiplication between spherical functions represented by a 2D Fourier basis can be efficiently computed via the convolution theorem and Fast Fourier Transforms. This transformation reduces the complexity of full tensor products of irreps from O(L^6) to O(L^3), where L is the max degree of irreps. Leveraging this approach, we introduce the Gaunt Tensor Product, which serves as a new method to construct efficient equivariant operations across different model architectures. Our experiments on the Open Catalyst Project and 3BPA datasets demonstrate both the increased efficiency and improved performance of our approach.

  • 3 authors
·
Jan 18, 2024

Facet: highly efficient E(3)-equivariant networks for interatomic potentials

Computational materials discovery is limited by the high cost of first-principles calculations. Machine learning (ML) potentials that predict energies from crystal structures are promising, but existing methods face computational bottlenecks. Steerable graph neural networks (GNNs) encode geometry with spherical harmonics, respecting atomic symmetries -- permutation, rotation, and translation -- for physically realistic predictions. Yet maintaining equivariance is difficult: activation functions must be modified, and each layer must handle multiple data types for different harmonic orders. We present Facet, a GNN architecture for efficient ML potentials, developed through systematic analysis of steerable GNNs. Our innovations include replacing expensive multi-layer perceptrons (MLPs) for interatomic distances with splines, which match performance while cutting computational and memory demands. We also introduce a general-purpose equivariant layer that mixes node information via spherical grid projection followed by standard MLPs -- faster than tensor products and more expressive than linear or gate layers. On the MPTrj dataset, Facet matches leading models with far fewer parameters and under 10% of their training compute. On a crystal relaxation task, it runs twice as fast as MACE models. We further show SevenNet-0's parameters can be reduced by over 25% with no accuracy loss. These techniques enable more than 10x faster training of large-scale foundation models for ML potentials, potentially reshaping computational materials discovery.

  • 9 authors
·
Sep 10, 2025

CompactifAI: Extreme Compression of Large Language Models using Quantum-Inspired Tensor Networks

Large Language Models (LLMs) such as ChatGPT and LLaMA are advancing rapidly in generative Artificial Intelligence (AI), but their immense size poses significant challenges, such as huge training and inference costs, substantial energy demands, and limitations for on-site deployment. Traditional compression methods such as pruning, distillation, and low-rank approximation focus on reducing the effective number of neurons in the network, while quantization focuses on reducing the numerical precision of individual weights to reduce the model size while keeping the number of neurons fixed. While these compression methods have been relatively successful in practice, there is no compelling reason to believe that truncating the number of neurons is an optimal strategy. In this context, this paper introduces CompactifAI, an innovative LLM compression approach using quantum-inspired Tensor Networks that focuses on the model's correlation space instead, allowing for a more controlled, refined and interpretable model compression. Our method is versatile and can be implemented with, or on top of, other compression techniques. As a benchmark, we demonstrate that combining CompactifAI with quantization reduces the memory size of LLaMA 7B by 93% and its number of parameters by 70%, accelerating training by 50% and inference by 25%, with only a small accuracy drop of 2-3%, going well beyond what is achievable today with other compression techniques. Our methods also allow refined layer sensitivity profiling, showing that deeper layers tend to be more suitable for tensor network compression, which is compatible with recent observations on the ineffectiveness of those layers for LLM performance. Our results imply that standard LLMs are, in fact, heavily overparametrized, and do not need to be large at all.

  • 18 authors
·
Jan 25, 2024

ABQ-LLM: Arbitrary-Bit Quantized Inference Acceleration for Large Language Models

Large Language Models (LLMs) have revolutionized natural language processing tasks. However, their practical application is constrained by substantial memory and computational demands. Post-training quantization (PTQ) is considered an effective method to accelerate LLM inference. Despite its growing popularity in LLM model compression, PTQ deployment faces two major challenges. First, low-bit quantization leads to performance degradation. Second, restricted by the limited integer computing unit type on GPUs, quantized matrix operations with different precisions cannot be effectively accelerated. To address these issues, we introduce a novel arbitrary-bit quantization algorithm and inference framework, ABQ-LLM. It achieves superior performance across various quantization settings and enables efficient arbitrary-precision quantized inference on the GPU. ABQ-LLM introduces several key innovations: (1) a distribution correction method for transformer blocks to mitigate distribution differences caused by full quantization of weights and activations, improving performance at low bit-widths; (2) a bit balance strategy to counteract performance degradation from asymmetric distribution issues at very low bit-widths (e.g., 2-bit); and (3) an innovative quantization acceleration framework that reconstructs quantized matrix multiplication for arbitrary precision combinations from BTC (Binary TensorCore) equivalents, removing the limitations of INT4/INT8 computing units. ABQ-LLM can convert each component bit width gain into actual acceleration gain, maximizing performance under mixed precision (e.g., W6A6, W2A8). With the W2*A8 quantization configuration on the LLaMA-7B model, it achieves a WikiText2 perplexity of 7.59 (2.17↓ vs. 9.76 in AffineQuant). Compared to SmoothQuant, we realize a 1.6× acceleration improvement and a 2.7× memory compression gain.

  • 9 authors
·
Aug 16, 2024
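
The ABQ-LLM abstract describes rebuilding arbitrary-precision matrix multiplication from binary equivalents. One way to picture this is bit-plane decomposition: each unsigned integer matrix is split into 0/1 matrices, and the full product is the shifted sum of all pairwise binary products. The sketch below is a generic illustration of that identity in pure Python, assuming unsigned values and ignoring scales, zero points, and signed encodings; it is not the paper's kernel, and the helper names are hypothetical.

```python
def bit_planes(matrix, bits):
    """Split a matrix of unsigned ints (< 2**bits) into `bits` binary matrices."""
    return [[[(x >> b) & 1 for x in row] for row in matrix] for b in range(bits)]

def matmul(a, b):
    """Plain integer matrix product for lists of lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def bitplane_matmul(w, x, w_bits, x_bits):
    """Compute w @ x as a shifted sum of products of 0/1 matrices."""
    rows, cols = len(w), len(x[0])
    out = [[0] * cols for _ in range(rows)]
    for i, w_plane in enumerate(bit_planes(w, w_bits)):
        for j, x_plane in enumerate(bit_planes(x, x_bits)):
            partial = matmul(w_plane, x_plane)  # binary-by-binary product
            for r in range(rows):
                for c in range(cols):
                    # Bit i of w times bit j of x carries weight 2**(i + j).
                    out[r][c] += partial[r][c] << (i + j)
    return out
```

For a W2A8 configuration this runs 2 × 8 = 16 binary products, so the work scales with the product of the two bit widths rather than being fixed by an INT4 or INT8 unit.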

PolySAE: Modeling Feature Interactions in Sparse Autoencoders via Polynomial Decoding

Sparse autoencoders (SAEs) have emerged as a promising method for interpreting neural network representations by decomposing activations into sparse combinations of dictionary atoms. However, SAEs assume that features combine additively through linear reconstruction, an assumption that cannot capture compositional structure: linear models cannot distinguish whether "Starbucks" arises from the composition of "star" and "coffee" features or merely their co-occurrence. This forces SAEs to allocate monolithic features for compound concepts rather than decomposing them into interpretable constituents. We introduce PolySAE, which extends the SAE decoder with higher-order terms to model feature interactions while preserving the linear encoder essential for interpretability. Through low-rank tensor factorization on a shared projection subspace, PolySAE captures pairwise and triple feature interactions with small parameter overhead (3% on GPT2). Across four language models and three SAE variants, PolySAE achieves an average improvement of approximately 8% in probing F1 while maintaining comparable reconstruction error, and produces 2-10× larger Wasserstein distances between class-conditional feature distributions. Critically, learned interaction weights exhibit negligible correlation with co-occurrence frequency (r = 0.06 vs. r = 0.82 for SAE feature covariance), suggesting that polynomial terms capture compositional structure, such as morphological binding and phrasal composition, largely independent of surface statistics.

  • 5 authors
·
Feb 1

Efficient and Scalable Density Functional Theory Hamiltonian Prediction through Adaptive Sparsity

Hamiltonian matrix prediction is pivotal in computational chemistry, serving as the foundation for determining a wide range of molecular properties. While SE(3) equivariant graph neural networks have achieved remarkable success in this domain, their substantial computational cost--driven by high-order tensor product (TP) operations--restricts their scalability to large molecular systems with extensive basis sets. To address this challenge, we introduce SPHNet, an efficient and scalable equivariant network, that incorporates adaptive SParsity into Hamiltonian prediction. SPHNet employs two innovative sparse gates to selectively constrain non-critical interaction combinations, significantly reducing tensor product computations while maintaining accuracy. To optimize the sparse representation, we develop a Three-phase Sparsity Scheduler, ensuring stable convergence and achieving high performance at sparsity rates of up to 70%. Extensive evaluations on QH9 and PubchemQH datasets demonstrate that SPHNet achieves state-of-the-art accuracy while providing up to a 7x speedup over existing models. Beyond Hamiltonian prediction, the proposed sparsification techniques also hold significant potential for improving the efficiency and scalability of other SE(3) equivariant networks, further broadening their applicability and impact. Our code can be found at https://github.com/microsoft/SPHNet.

  • 10 authors
·
Feb 3, 2025

FlashHead: Efficient Drop-In Replacement for the Classification Head in Language Model Inference

Language models are increasingly adopting smaller architectures optimized for consumer devices. In this setting, inference efficiency is the primary constraint. Meanwhile, vocabulary sizes continue to grow rapidly, making the classification head a critical bottleneck that accounts for up to 60% of model parameters and 50% of inference compute. We introduce FlashHead, the first efficient drop-in replacement for the dense classification head that is training-free and hardware-friendly. FlashHead builds on principles from information retrieval, reframing the computation at the output head as a retrieval problem rather than a dense classification over the full vocabulary. FlashHead introduces four key innovations: (1) a balanced clustering scheme that structures vocabulary partitions into compact hardware-efficient tensors, (2) extending multiprobe retrieval to language model heads, enabling thousands of clusters to be scored in parallel, (3) a novel inference-time sampling mechanism that extends retrieval beyond top tokens, enabling probabilistic sampling across the full vocabulary, and (4) selective quantization, enabling effective low-bit computation in the head. Experiments on Llama-3.2, Gemma-3, and Qwen-3 show that FlashHead delivers model-level inference speedups of up to 1.75x while maintaining output accuracy compared to the original head. By overcoming the classification head bottleneck, FlashHead establishes a new benchmark for efficient inference and removes a key barrier to developing smaller, capable models for consumer hardware.

  • 5 authors
·
Mar 15

How to Capture Higher-order Correlations? Generalizing Matrix Softmax Attention to Kronecker Computation

In the classical transformer attention scheme, we are given three $n \times d$ matrices $Q$, $K$, $V$ (the query, key, and value tokens), and the goal is to compute a new $n \times d$ matrix $D^{-1} \exp(QK^\top) V$, where $D = \mathrm{diag}(\exp(QK^\top) \mathbf{1}_n)$. In this work, we study a generalization of attention which captures triple-wise correlations. This generalization is able to solve problems about detecting triple-wise connections that were shown to be impossible for transformers. The potential downside of this generalization is that it appears as though computations are even more difficult, since the straightforward algorithm requires cubic time in $n$. However, we show that in the bounded-entry setting (which arises in practice, and which is well-studied in both theory and practice), there is actually a near-linear time algorithm. More precisely, we show that bounded entries are both necessary and sufficient for quickly performing generalized computations: (i) On the positive side, if all entries of the input matrices are bounded above by $o(\sqrt[3]{\log n})$, then we show how to approximate the "tensor-type" attention matrix in $n^{1+o(1)}$ time. (ii) On the negative side, we show that if the entries of the input matrices may be as large as $\Omega(\sqrt[3]{\log n})$, then there is no algorithm that runs faster than $n^{3-o(1)}$ (assuming the Strong Exponential Time Hypothesis from fine-grained complexity theory). We also show that our construction, algorithms, and lower bounds naturally generalize to higher-order tensors and correlations. Interestingly, the higher the order of the tensors, the lower the bound on the entries needs to be for an efficient algorithm. Our results thus yield a natural tradeoff between the boundedness of the entries, and order of the tensor one may use for more expressive, efficient attention computation.

  • 2 authors
·
Oct 6, 2023
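
To make the cubic cost of triple-wise attention concrete, the sketch below implements a straightforward version in pure Python. It assumes one natural formalization, which may differ in detail from the paper's definition: each query scores every pair of positions (j, k) with the three-way inner product of the query row and two key rows, and each pair's value is the elementwise product of two value rows. The function name and the two key/value matrices `K1`, `K2`, `V1`, `V2` are illustrative assumptions.

```python
import math

def third_order_attention(Q, K1, K2, V1, V2):
    """Naive triple-wise attention: n queries x n*n position pairs, so O(n^3 d)."""
    n, d = len(Q), len(Q[0])
    out = []
    for i in range(n):
        # Score every pair (j, k) with a three-way inner product.
        scores = {}
        for j in range(n):
            for k in range(n):
                scores[(j, k)] = sum(Q[i][t] * K1[j][t] * K2[k][t] for t in range(d))
        m = max(scores.values())  # subtract the max for numerical stability
        denom = 0.0
        row = [0.0] * d
        for (j, k), s in scores.items():
            w = math.exp(s - m)
            denom += w
            for t in range(d):
                # The value for the pair is the elementwise product of two rows.
                row[t] += w * V1[j][t] * V2[k][t]
        out.append([x / denom for x in row])
    return out
```

With all-zero queries every pair receives equal weight, so each output row reduces to the mean of the pairwise value products, which gives a quick sanity check on the normalization.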

Semantic MapNet: Building Allocentric Semantic Maps and Representations from Egocentric Views

We study the task of semantic mapping - specifically, an embodied agent (a robot or an egocentric AI assistant) is given a tour of a new environment and asked to build an allocentric top-down semantic map ("what is where?") from egocentric observations of an RGB-D camera with known pose (via localization sensors). Towards this goal, we present SemanticMapNet (SMNet), which consists of: (1) an Egocentric Visual Encoder that encodes each egocentric RGB-D frame, (2) a Feature Projector that projects egocentric features to appropriate locations on a floor-plan, (3) a Spatial Memory Tensor of size floor-plan length x width x feature-dims that learns to accumulate projected egocentric features, and (4) a Map Decoder that uses the memory tensor to produce semantic top-down maps. SMNet combines the strengths of (known) projective camera geometry and neural representation learning. On the task of semantic mapping in the Matterport3D dataset, SMNet significantly outperforms competitive baselines by 4.01-16.81% (absolute) on mean-IoU and 3.81-19.69% (absolute) on Boundary-F1 metrics. Moreover, we show how to use the neural episodic memories and spatio-semantic allocentric representations built by SMNet for subsequent tasks in the same space - navigating to objects seen during the tour ("Find chair") or answering questions about the space ("How many chairs did you see in the house?"). Project page: https://vincentcartillier.github.io/smnet.html.

  • 6 authors
·
Oct 2, 2020

Orbital Graph Convolutional Neural Network for Material Property Prediction

Material representations that are compatible with machine learning models play a key role in developing models that exhibit high accuracy for property prediction. Atomic orbital interactions are one of the important factors that govern the properties of crystalline materials, from which the local chemical environments of atoms are inferred. Therefore, to develop robust machine learning models for material property prediction, it is imperative to include features representing such chemical attributes. Here, we propose the Orbital Graph Convolutional Neural Network (OGCNN), a crystal graph convolutional neural network framework that includes atomic orbital interaction features and learns material properties in a robust way. In addition, we embedded an encoder-decoder network into the OGCNN, enabling it to learn important features among basic atomic (elemental) features, orbital-orbital interactions, and topological features. We examined the performance of this model on a broad range of crystalline material data to predict different properties. We benchmarked the performance of the OGCNN model against: 1) the crystal graph convolutional neural network (CGCNN), 2) other state-of-the-art descriptors for material representations, including Many-body Tensor Representation (MBTR) and the Smooth Overlap of Atomic Positions (SOAP), and 3) other conventional regression machine learning algorithms where different crystal featurization methods have been used. We find that OGCNN significantly outperforms them. The OGCNN model, with its high predictive accuracy, can be used to discover new materials among the immense phase and compound spaces of materials.

  • 6 authors
·
Aug 14, 2020

Detect Anything in Real Time: From Single-Prompt Segmentation to Multi-Class Detection

Recent advances in vision-language modeling have produced promptable detection and segmentation systems that accept arbitrary natural language queries at inference time. Among these, SAM3 achieves state-of-the-art accuracy by combining a ViT-H/14 backbone with cross-modal transformer decoding and learned object queries. However, SAM3 processes a single text prompt per forward pass. Detecting N categories requires N independent executions, each dominated by the 439M-parameter backbone. We present Detect Anything in Real Time (DART), a training-free framework that converts SAM3 into a real-time multi-class detector by exploiting a structural invariant: the visual backbone is class-agnostic, producing image features independent of the text prompt. This allows the backbone computation to be shared between all classes, reducing its cost from O(N) to O(1). Combined with batched multi-class decoding, detection-only inference, and TensorRT FP16 deployment, these optimizations yield 5.6x cumulative speedup at 3 classes, scaling to 25x at 80 classes, without modifying any model weight. On COCO val2017 (5,000 images, 80 classes), DART achieves 55.8 AP at 15.8 FPS (4 classes, 1008x1008) on a single RTX 4080, surpassing purpose-built open-vocabulary detectors trained on millions of box annotations. For extreme latency targets, adapter distillation with a frozen encoder-decoder achieves 38.7 AP with a 13.9 ms backbone. Code and models are available at https://github.com/mkturkcan/DART.

  • 1 author
·
Mar 11
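
The O(N) to O(1) backbone reduction described in the DART abstract follows from computing prompt-independent image features once and reusing them for every class. The toy sketch below counts backbone invocations under both schedules. The `CountingBackbone` class, the `decode` function, and the arithmetic inside them are hypothetical stand-ins, not SAM3 or DART code.

```python
class CountingBackbone:
    """Stand-in for a class-agnostic visual backbone that counts its calls."""

    def __init__(self):
        self.calls = 0

    def __call__(self, image):
        self.calls += 1
        # Toy "features": one number per image row, independent of any prompt.
        return [float(sum(row)) for row in image]

def decode(features, prompt):
    """Hypothetical prompt-conditioned decoder head (toy arithmetic)."""
    return sum(features) * len(prompt)

def detect_per_prompt(backbone, image, prompts):
    """Baseline schedule: one full forward pass per prompt (N backbone calls)."""
    return {p: decode(backbone(image), p) for p in prompts}

def detect_shared(backbone, image, prompts):
    """Shared schedule: run the backbone once, then decode each prompt."""
    features = backbone(image)
    return {p: decode(features, p) for p in prompts}
```

Both schedules return identical results because the features do not depend on the prompt; only the number of backbone calls changes.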

QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving

Quantization can accelerate large language model (LLM) inference. Going beyond INT8 quantization, the research community is actively exploring even lower precision, such as INT4. Nonetheless, state-of-the-art INT4 quantization techniques only accelerate low-batch, edge LLM inference, failing to deliver performance gains in large-batch, cloud-based LLM serving. We uncover a critical issue: existing INT4 quantization methods suffer from significant runtime overhead (20-90%) when dequantizing either weights or partial sums on GPUs. To address this challenge, we introduce QoQ, a W4A8KV4 quantization algorithm with 4-bit weight, 8-bit activation, and 4-bit KV cache. QoQ stands for quattuor-octo-quattuor, which represents 4-8-4 in Latin. QoQ is implemented by the QServe inference library, which achieves measured speedup. The key insight driving QServe is that the efficiency of LLM serving on GPUs is critically influenced by operations on low-throughput CUDA cores. Building upon this insight, in the QoQ algorithm, we introduce progressive quantization that allows low dequantization overhead in W4A8 GEMM. Additionally, we develop SmoothAttention to effectively mitigate the accuracy degradation incurred by 4-bit KV quantization. In the QServe system, we perform compute-aware weight reordering and take advantage of register-level parallelism to reduce dequantization latency. We also make fused attention memory-bound, harnessing the performance gain brought by KV4 quantization. As a result, QServe improves the maximum achievable serving throughput of Llama-3-8B by 1.2x on A100, 1.4x on L40S; and Qwen1.5-72B by 2.4x on A100, 3.5x on L40S, compared to TensorRT-LLM. Remarkably, QServe on L40S GPU can achieve even higher throughput than TensorRT-LLM on A100. Thus, QServe effectively reduces the dollar cost of LLM serving by 3x. Code is available at https://github.com/mit-han-lab/qserve.

  • 7 authors
·
May 7, 2024

Empowering 1000 tokens/second on-device LLM prefilling with mllm-NPU

On-device large language models (LLMs) are catalyzing novel mobile applications such as UI task automation and personalized email auto-reply, without giving away users' private data. However, on-device LLMs still suffer from unacceptably long inference latency, especially the time to first token (prefill stage) due to the need of long context for accurate, personalized content generation, as well as the lack of parallel computing capacity of mobile CPU/GPU. To enable practical on-device LLM, we present mllm-NPU, the first-of-its-kind LLM inference system that efficiently leverages on-device Neural Processing Unit (NPU) offloading. Essentially, mllm-NPU is an algorithm-system co-design that tackles a few semantic gaps between the LLM architecture and contemporary NPU design. Specifically, it re-constructs the prompt and model in three levels: (1) At prompt level, it divides variable-length prompts into multiple fixed-sized chunks while maintaining data dependencies; (2) At tensor level, it identifies and extracts significant outliers to run on the CPU/GPU in parallel with minimal overhead; (3) At block level, it schedules Transformer blocks in an out-of-order manner to the CPU/GPU and NPU based on their hardware affinity and sensitivity to accuracy. Compared to competitive baselines, mllm-NPU achieves 22.4x faster prefill speed and 30.7x energy savings on average, and up to 32.8x speedup in an end-to-end real-world application. For the first time, mllm-NPU achieves more than 1,000 tokens/sec prefilling for a billion-sized model (Qwen1.5-1.8B), paving the way towards practical on-device LLM.

  • 7 authors
·
Jul 8, 2024
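
The prompt-level step in the mllm-NPU abstract, splitting a variable-length prompt into fixed-size chunks, can be sketched as below. The sketch assumes a pad token fills the final chunk and that recording each chunk's start offset is enough for a later stage to preserve causal dependencies on earlier tokens. The function name, the tuple layout, and the padding policy are illustrative assumptions rather than the system's actual interface.

```python
def chunk_prompt(token_ids, chunk_size, pad_id=0):
    """Split a prompt into fixed-size chunks.

    Returns a list of (start_offset, chunk, valid_length) tuples. Every chunk
    has exactly `chunk_size` entries so a fixed-shape compute graph can be
    reused; `valid_length` marks how many entries are real tokens.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    chunks = []
    for start in range(0, len(token_ids), chunk_size):
        piece = list(token_ids[start:start + chunk_size])
        valid = len(piece)
        piece.extend([pad_id] * (chunk_size - valid))  # pad the last chunk
        chunks.append((start, piece, valid))
    return chunks
```

A ten-token prompt with a chunk size of four produces three chunks, the last holding two real tokens and two pad entries. Chunk `i` would need the cached state of tokens `0 .. start - 1` to keep the data dependencies the abstract mentions.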

FlashSchNet: Fast and Accurate Coarse-Grained Neural Network Molecular Dynamics

Graph neural network (GNN) potentials such as SchNet improve the accuracy and transferability of molecular dynamics (MD) simulation by learning many-body interactions, but remain slower than classical force fields due to fragmented kernels and memory-bound pipelines that underutilize GPUs. We show that a missing principle is making GNN-MD IO-aware, carefully accounting for reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM. We present FlashSchNet, an efficient and accurate IO-aware SchNet-style GNN-MD framework built on four techniques: (1) flash radial basis, which fuses pairwise distance computation, Gaussian basis expansion, and cosine envelope into a single tiled pass, computing each distance once and reusing it across all basis functions; (2) flash message passing, which fuses cutoff, neighbor gather, filter multiplication, and reduction to avoid materializing edge tensors in HBM; (3) flash aggregation, which reformulates scatter-add via CSR segment reduce, reducing atomic writes by a factor of feature dimension and enabling contention-free accumulation in both forward and backward passes; (4) channel-wise 16-bit quantization that exploits the low per-channel dynamic range in SchNet MLP weights to further improve throughput with negligible accuracy loss. On a single NVIDIA RTX PRO 6000, FlashSchNet achieves 1000 ns/day aggregate simulation throughput over 64 parallel replicas on coarse-grained (CG) protein containing 269 beads (6.5x faster than CGSchNet baseline with 80% reduction of peak memory), surpassing classical force fields (e.g. MARTINI) while retaining SchNet-level accuracy and transferability.

  • 5 authors
·
Feb 13
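
The flash aggregation step in the FlashSchNet abstract replaces scatter-add with a CSR segment reduce. The pure-Python sketch below shows the idea on the CPU, assuming edges are already sorted by destination node: a row-pointer array marks each node's contiguous edge range, so each output row is written once instead of being accumulated edge by edge. Function names are hypothetical, and nothing here reflects the paper's GPU kernel.

```python
def scatter_add(messages, dst, num_nodes):
    """Reference: accumulate each edge message into its destination row."""
    dim = len(messages[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for msg, d in zip(messages, dst):
        for t in range(dim):
            out[d][t] += msg[t]  # on a GPU this would be an atomic add
    return out

def build_row_ptr(dst_sorted, num_nodes):
    """CSR row pointer: edges row_ptr[i]:row_ptr[i+1] all target node i."""
    row_ptr = [0] * (num_nodes + 1)
    for d in dst_sorted:
        row_ptr[d + 1] += 1
    for i in range(num_nodes):
        row_ptr[i + 1] += row_ptr[i]
    return row_ptr

def segment_reduce(messages, row_ptr):
    """Sum each node's contiguous slice of messages; one write per output row."""
    dim = len(messages[0]) if messages else 0
    out = []
    for i in range(len(row_ptr) - 1):
        segment = messages[row_ptr[i]:row_ptr[i + 1]]
        out.append([sum(m[t] for m in segment) for t in range(dim)])
    return out
```

Because each segment belongs to exactly one node, a parallel implementation can assign one worker per segment and avoid write contention, which is the property the abstract refers to as contention-free accumulation.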