GPU memory is the single most constraining resource in modern deep learning. Every design decision in a training or inference system — batch size, sequence length, model architecture, parallelism strategy — ultimately traces back to GPU memory capacity and bandwidth. A 70-billion-parameter model in BF16 precision consumes 140 GB for parameters alone, before accounting for optimizer states, activations, and gradients. Understanding exactly how memory is consumed, and the techniques available to reduce it, is fundamental knowledge for any ML infrastructure engineer.
This article provides a detailed technical treatment of GPU memory management for deep learning workloads, covering the memory hierarchy of modern GPUs, the breakdown of memory consumers during training, and a systematic review of the techniques that have made training of trillion-parameter models possible on finite hardware. The same principles apply, with different tradeoffs, to inference workloads where KV cache management introduces its own set of constraints.
The GPU Memory Hierarchy
Modern data-center GPUs have a multi-level memory hierarchy. At the top sits on-chip SRAM, which is small but extremely fast. The H100 has 50 MB of L2 cache and roughly 33 MB of combined L1/shared memory across its 132 streaming multiprocessors (SMs), 256 KB per SM. This on-chip memory delivers aggregate bandwidth in the tens of terabytes per second and is where kernels stage the data they are actively working on. Below it sits HBM (High Bandwidth Memory), the primary GPU memory visible to PyTorch and TensorFlow. The H100 SXM5 offers 80 GB of HBM3 at approximately 3.35 TB/s, a significant step down from on-chip speeds.
The bandwidth gap between HBM and on-chip SRAM explains why memory-bound operations (those that read large tensors and perform little computation per byte) run far below the GPU's peak FLOP rate. The arithmetic intensity of an operation, FLOPs divided by bytes moved to and from HBM, determines whether it is limited by compute throughput or by memory bandwidth. In standard attention, the softmax, masking, and dropout steps each read and write the full T × T score matrix in HBM, so the bottleneck is memory bandwidth rather than the H100's roughly 989 TFLOPS of dense BF16 tensor compute. Flash Attention's central contribution is restructuring the attention computation into tiles that stay in SRAM, so the score matrix is never written to HBM and far fewer bytes move per FLOP.
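A quick roofline-style check makes the compute-bound versus memory-bound distinction concrete. The sketch below assumes the H100 SXM5 figures quoted above (about 989 TFLOPS dense BF16 and 3.35 TB/s HBM) and counts only the minimum HBM traffic (each operand read once, the result written once), so real kernels will land somewhat lower.

```python
PEAK_FLOPS = 989e12          # H100 SXM5 dense BF16 tensor throughput (FLOP/s)
HBM_BW = 3.35e12             # H100 SXM5 HBM3 bandwidth (bytes/s)
RIDGE = PEAK_FLOPS / HBM_BW  # FLOP/byte needed before compute, not HBM, is the limit
BYTES_BF16 = 2

def elementwise_add_intensity(n: int) -> float:
    """c = a + b over n BF16 elements: n FLOPs, two reads and one write per element."""
    return n / (3 * n * BYTES_BF16)

def matmul_intensity(m: int, n: int, k: int) -> float:
    """C[m,n] = A[m,k] @ B[k,n] in BF16: 2mnk FLOPs, A and B read once, C written once."""
    flops = 2 * m * n * k
    bytes_moved = (m * k + k * n + m * n) * BYTES_BF16
    return flops / bytes_moved

def bound(intensity: float) -> str:
    return "compute-bound" if intensity >= RIDGE else "memory-bound"

if __name__ == "__main__":
    print(f"ridge point: {RIDGE:.0f} FLOP/byte")
    for name, ai in [("add, 16M elements", elementwise_add_intensity(1 << 24)),
                     ("matmul 128^3", matmul_intensity(128, 128, 128)),
                     ("matmul 4096^3", matmul_intensity(4096, 4096, 4096))]:
        print(f"{name}: {ai:.2f} FLOP/byte -> {bound(ai)}")
```

The ridge point is about 295 FLOP/byte. An elementwise add sits near 0.17 FLOP/byte and a 128×128×128 matmul near 43, both memory-bound, while a 4096-cubed matmul reaches about 1365 and can saturate the tensor cores.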
Beyond HBM sit host DRAM (CPU memory, typically hundreds of GB per node), NVMe storage, and the HBM of peer GPUs reachable over NVLink. Techniques like ZeRO-Infinity and CPU offloading exploit this hierarchy deliberately, trading the much lower bandwidth of PCIe (about 64 GB/s per direction for a Gen5 x16 link) and NVMe for vastly larger capacity. Understanding this hierarchy helps practitioners choose the right technique: if memory is insufficient and the transfer time can be hidden behind computation, offloading is viable; if compute throughput is the binding constraint, reducing numerical precision is usually the better lever.
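A back-of-envelope estimate shows when offloading can pay off. This sketch assumes the ~64 GB/s PCIe figure above, a ZeRO-Offload-style pattern in which BF16 gradients go to the host and updated BF16 weights come back while optimizer state stays in host DRAM, and serialized (not concurrent) transfers in the two directions. The function names are illustrative.

```python
PCIE_BW = 64e9    # bytes/s per direction (the PCIe figure quoted above)
HBM_BW = 3.35e12  # bytes/s, H100 SXM5 HBM3

def transfer_seconds(num_bytes: float, bandwidth: float) -> float:
    """Idealized transfer time: bytes divided by sustained bandwidth."""
    return num_bytes / bandwidth

def offload_traffic_seconds(params: int, bytes_per_elem: int = 2,
                            bandwidth: float = PCIE_BW) -> float:
    """Per-step PCIe time for BF16 gradients to the host plus updated BF16
    weights back to the GPU, with the two directions serialized."""
    one_way = transfer_seconds(params * bytes_per_elem, bandwidth)
    return 2 * one_way

def offload_is_viable(params: int, step_compute_seconds: float) -> bool:
    """True if the transfers fit inside the step's compute time, assuming they
    can be fully overlapped with that compute."""
    return offload_traffic_seconds(params) <= step_compute_seconds
```

For a 7B-parameter model this comes to about 0.44 s of PCIe traffic per step, versus roughly 4 ms to read the same 14 GB of weights once from HBM, so offloading only pays off when each step carries enough computation to hide the transfers.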
Memory Consumers During Training
During model training, GPU memory is consumed by five distinct categories. Understanding each is essential for estimating memory requirements and identifying optimization targets. Model parameters are the most straightforward: a model with P parameters consumes 4P bytes in 32-bit floating point, or 2P bytes when stored in BF16 or FP16. For a 7B-parameter model in BF16, this is 14 GB; for GPT-3 at 175B parameters, it is 350 GB.
Optimizer states are often the largest category in standard training. The Adam optimizer maintains a first moment (an exponential moving average of the gradient) and a second moment (an exponential moving average of the squared gradient, i.e. its uncentered variance) for each parameter, both in 32-bit float in standard implementations: 8 bytes per parameter. Mixed-precision training also keeps an FP32 master copy of the weights (4 bytes per parameter), so optimizer state totals 12 bytes per parameter, six times the 2-byte BF16 weights. For a 7B-parameter model, that is 84 GB of optimizer state (28 GB FP32 master copy + 56 GB moments) on top of 14 GB of BF16 parameters. ZeRO shards these optimizer states across the data-parallel group, starting with Stage 1, to distribute this burden.
Gradients produced during the backward pass need one value per parameter, usually in the same precision as the working parameters: typically 2 bytes per parameter in mixed-precision training. Activations are often the most dynamic and difficult-to-predict consumer: they are the intermediate tensors computed during the forward pass that must be retained for use in the backward pass. For transformer models, activation memory scales with batch size, sequence length, number of layers, and hidden dimension, and can easily exceed parameter memory for large batches or long sequences. The final category is temporary workspace: scratch buffers used by cuBLAS, cuDNN, and other kernels during computation, usually a small share of the total (often a few hundred MB).
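Adding up the model-state categories gives the 16 bytes per parameter commonly used for mixed-precision Adam, the same accounting used in the ZeRO paper. The sketch below assumes BF16 weights and gradients, an FP32 master copy, and FP32 Adam moments; activations and workspace are deliberately left out because they depend on batch size and sequence length. `BytesPerParam` and `model_state_gb` are illustrative names.

```python
from dataclasses import dataclass
from typing import Dict

GB = 10**9

@dataclass(frozen=True)
class BytesPerParam:
    weights: int = 2       # BF16 working copy used in forward/backward
    gradients: int = 2     # BF16 gradients
    master: int = 4        # FP32 master copy of the weights
    adam_moments: int = 8  # FP32 first + second moments (4 + 4)

    @property
    def total(self) -> int:
        return self.weights + self.gradients + self.master + self.adam_moments

def model_state_gb(params: int, b: BytesPerParam = BytesPerParam()) -> Dict[str, float]:
    """Per-GPU model-state memory (GB) under plain data parallelism (no sharding).
    Activations and workspace are excluded."""
    return {
        "parameters": params * b.weights / GB,
        "gradients": params * b.gradients / GB,
        "optimizer": params * (b.master + b.adam_moments) / GB,
        "total": params * b.total / GB,
    }
```

For a 7B-parameter model this gives 14 + 14 + 84 = 112 GB, more than a single 80 GB H100 holds before a single activation is stored.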
Activation Checkpointing
Activation checkpointing (also called gradient checkpointing or rematerialization) is the primary technique for trading compute for memory. Without checkpointing, the activations from every layer must be retained in memory from their computation in the forward pass until they are consumed in the backward pass. For a 40-layer transformer with a large batch, this can require tens or hundreds of gigabytes just for activations.
With activation checkpointing, only a subset of activations (the "checkpoints") is retained. The remaining activations are discarded after the forward pass and recomputed on demand during the backward pass from the nearest checkpoint. The most common granularity is per transformer block: the input to each transformer layer is retained, while all intermediate activations within the layer (attention intermediates, MLP hidden activations) are recomputed. Without checkpointing, each layer stores more than ten times the size of its B × T × D input, plus attention-score tensors that grow with T² unless a fused kernel such as Flash Attention avoids storing them, so stored activations scale as O(L × B × T × D) with a large constant. With per-layer checkpointing, only L tensors of size B × T × D are kept, plus the full intermediates of a single layer while it is being recomputed.
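To put numbers on this, the sketch below uses the per-layer estimate from Korthikanti et al., "Reducing Activation Recomputation in Large Transformer Models": roughly T·B·D·(34 + 5·a·T/D) bytes per transformer layer with 16-bit activations and no tensor or sequence parallelism, where a is the number of attention heads. Treat the constants as an approximation for a GPT-style block; the model shape in the tests and usage note is illustrative.

```python
def layer_activation_bytes(T: int, B: int, D: int, n_heads: int) -> int:
    """Approximate bytes one transformer layer keeps for the backward pass:
    T*B*D*(34 + 5*n_heads*T/D), written in integer form. The second term covers
    attention scores, softmax output, and dropout mask, which grow with T**2."""
    return 34 * T * B * D + 5 * n_heads * T * T * B

def activation_bytes(L: int, T: int, B: int, D: int, n_heads: int,
                     checkpoint_each_layer: bool = False) -> int:
    full = layer_activation_bytes(T, B, D, n_heads)
    if not checkpoint_each_layer:
        return L * full  # every layer keeps all of its intermediates
    # Keep only each layer's 16-bit input (2*T*B*D bytes) as its checkpoint, plus
    # one layer's full intermediates while that layer is recomputed in backward.
    return L * 2 * T * B * D + full
```

For a 32-layer model with D = 4096, 32 heads, T = 2048 and B = 1, this estimates about 30.6 GB of activations without checkpointing and about 1.5 GB with per-layer checkpointing.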
The compute cost of full activation checkpointing is approximately one additional forward pass per training step. Since the backward pass costs roughly twice the forward pass, a step grows from about three forward-pass equivalents to four, roughly a 33% increase in compute. For large model training where memory is the binding constraint, this tradeoff is almost always worthwhile. PyTorch's torch.utils.checkpoint.checkpoint function implements this transparently for arbitrary functions and submodules. Selective recomputation, which recomputes only operations whose outputs are large but cheap to recompute (such as attention scores, softmax outputs, and dropout masks), can recover most of the memory savings with a much smaller compute overhead.
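The mechanics can be shown without any framework. In this toy, each "layer" is a scalar function paired with its derivative (an assumption purely for illustration; this is not how torch.utils.checkpoint works internally). It compares storing every layer input against keeping one checkpoint per segment and recomputing the rest during the backward pass; both give the same gradient.

```python
import math
from typing import Callable, List, Tuple

# A toy "layer": (forward function, derivative of its output w.r.t. its input)
Layer = Tuple[Callable[[float], float], Callable[[float], float]]

def grad_store_all(layers: List[Layer], x0: float):
    """Standard backprop: keep every layer input from the forward pass."""
    inputs, x = [], x0
    for f, _ in layers:
        inputs.append(x)
        x = f(x)
    grad = 1.0
    for (_, df), xi in zip(reversed(layers), reversed(inputs)):
        grad *= df(xi)  # chain rule, applied from output back to input
    return x, grad, len(inputs)  # output, d(out)/d(x0), values stored

def grad_checkpointed(layers: List[Layer], x0: float, segment: int):
    """Keep only the input of every `segment`-th layer; recompute the rest."""
    checkpoints, x = {}, x0
    for i, (f, _) in enumerate(layers):
        if i % segment == 0:
            checkpoints[i] = x
        x = f(x)
    out, grad, recomputed, peak = x, 1.0, 0, 0
    for start in reversed(range(0, len(layers), segment)):
        end = min(start + segment, len(layers))
        seg_inputs = [checkpoints[start]]
        for j in range(start, end - 1):  # rematerialize this segment's inputs
            seg_inputs.append(layers[j][0](seg_inputs[-1]))
            recomputed += 1
        # live values: all checkpoints plus this segment's recomputed inputs
        peak = max(peak, len(checkpoints) + len(seg_inputs) - 1)
        for j in reversed(range(start, end)):
            grad *= layers[j][1](seg_inputs[j - start])
    return out, grad, peak, recomputed

if __name__ == "__main__":
    net = [(math.sin, math.cos)] * 16
    print(grad_store_all(net, 0.5))
    print(grad_checkpointed(net, 0.5, segment=4))
```

With 16 layers and segments of 4, peak storage drops from 16 values to 7 at the cost of 12 extra layer evaluations, about three quarters of an extra forward pass. For this scheme, segments of roughly √n layers minimize peak storage.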
Mixed Precision Training and Memory Implications
Mixed precision training with BF16 or FP16 halves the memory needed for activations and for the working copy of the parameters compared to FP32, but requires careful management of numerical precision. The standard approach, first described by NVIDIA in the 2017 "Mixed Precision Training" paper, maintains a master copy of the parameters in FP32 for the optimizer update while using FP16 for the forward and backward passes. This captures the speed and activation-memory savings of lower-precision arithmetic while preserving the numerical accuracy needed for small optimizer updates to accumulate.
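The need for the FP32 master copy is easy to reproduce. In the plain-Python sketch below, `struct`'s `'e'` and `'f'` formats stand in for FP16 and FP32 storage, and the learning rate and gradient are arbitrary illustrative values. An update of 1e-4 to a weight of 1.0 is smaller than half the gap between 1.0 and the next FP16 value below it (2^-11, about 4.9e-4), so it rounds away entirely unless it accumulates in FP32. Loss scaling, which the FP16 recipe also needs, is omitted.

```python
import struct

def to_fp16(x: float) -> float:
    """Round to the nearest IEEE 754 half-precision value (round-half-even)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]

def to_fp32(x: float) -> float:
    """Round to the nearest IEEE 754 single-precision value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]

def sgd(steps: int, lr: float, grad: float, fp32_master: bool) -> float:
    """Apply `steps` identical SGD updates to a weight initialized at 1.0 and
    return the FP16 working copy the forward pass would see."""
    master = to_fp32(1.0)
    w16 = to_fp16(1.0)
    for _ in range(steps):
        if fp32_master:
            master = to_fp32(master - lr * grad)  # update accumulates in FP32
            w16 = to_fp16(master)                 # recast for the next forward pass
        else:
            w16 = to_fp16(w16 - lr * grad)        # update applied directly in FP16
    return w16
```

After 100 steps, the pure-FP16 weight is still exactly 1.0, while the master-weight version has moved to about 0.99 as intended.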
BF16 (Brain Floating Point 16) has displaced FP16 as the dominant format for LLM training. BF16 has the same 8-bit exponent as FP32 and therefore essentially the same dynamic range, whereas FP16's 5-bit exponent (maximum value 65504) lets small gradient values underflow to zero and large activations overflow. BF16's lower mantissa precision (7 bits vs FP16's 10 bits) is generally acceptable for training, and dropping the loss scaling that FP16 requires simplifies training recipes. A100 and H100 GPUs have native BF16 tensor cores that match the throughput of FP16.
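The range-versus-precision tradeoff can be checked directly. This sketch emulates both formats in plain Python, under the assumption that BF16 rounding is equivalent to rounding an FP32 bit pattern to its upper 16 bits (round-half-even) and using `struct`'s `'e'` format for FP16. It shows a tiny value underflowing in FP16 but surviving in BF16, a large value overflowing FP16 but not BF16, and a fine increment that FP16 keeps but BF16 loses.

```python
import math
import struct

def to_fp16(x: float) -> float:
    """Nearest IEEE FP16 value; magnitudes beyond the FP16 range become +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct refuses values that round past 65504
        return math.copysign(math.inf, x)

def to_bf16(x: float) -> float:
    """Nearest BF16 value: round an FP32 bit pattern to its top 16 bits
    (round-half-even). NaN inputs are not handled in this sketch."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + 0x7FFF + ((bits >> 16) & 1)) >> 16 << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]

if __name__ == "__main__":
    for v in (1e-8, 70000.0, 1 + 2**-10):
        print(f"{v!r:>22}  fp16={to_fp16(v)!r:<22} bf16={to_bf16(v)!r}")
```

Here 1e-8 becomes 0.0 in FP16, 70000 becomes inf in FP16 but 70144 in BF16, and 1 + 2^-10 is exact in FP16 but rounds to 1.0 in BF16.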
FP8 training, enabled by the H100's FP8 tensor cores, pushes precision reduction further. NVIDIA's Transformer Engine implements FP8 training with per-tensor scaling factors that keep each tensor within FP8's narrow representable range. The H100's FP8 tensor cores offer twice the peak throughput of BF16, and FP8 training maintains model quality comparable to BF16 for most workloads, although end-to-end speedups are smaller than 2x because not every operation runs in FP8. FP8's memory savings over BF16 are meaningful for activation storage, but because the FP32 master weights and optimizer states are unchanged, the overall memory reduction for a full training configuration is far more modest than the raw 2x figure suggests.
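Per-tensor scaling is what makes such a narrow format usable. The sketch below emulates the FP8 E4M3 format (3 mantissa bits, largest finite value 448) in plain Python and maps each tensor's largest magnitude (amax) onto 448 before rounding. Computing the scale from the tensor's own amax is a simplification: Transformer Engine's delayed-scaling recipe derives the scale from a history of previously observed amax values, and nothing here reflects its API. The sample values are illustrative.

```python
import math
from typing import List

E4M3_MAX = 448.0  # largest finite FP8 E4M3 value

def round_e4m3(x: float) -> float:
    """Round to the nearest E4M3 value (3 mantissa bits, exponent bias 7,
    subnormal spacing 2**-9), saturating at +/-448 instead of overflowing."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)      # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)      # unbiased exponent, clamped at the subnormal floor
    step = 2.0 ** (exp - 3)   # spacing between neighbouring values in this binade
    q = min(round(a / step) * step, E4M3_MAX)
    return math.copysign(q, x)

def fp8_roundtrip(values: List[float], per_tensor_scale: bool) -> List[float]:
    """Quantize to E4M3 and back; with scaling, the tensor's amax maps to 448."""
    scale = E4M3_MAX / max(abs(v) for v in values) if per_tensor_scale else 1.0
    return [round_e4m3(v * scale) / scale for v in values]

if __name__ == "__main__":
    vals = [1e-4, -2e-4, 3e-4, 5e-4]
    print(fp8_roundtrip(vals, per_tensor_scale=False))
    print(fp8_roundtrip(vals, per_tensor_scale=True))
```

Without scaling, every value is below half the smallest E4M3 subnormal and flushes to zero; with scaling, each round-trips with a relative error of at most 1/16.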
ZeRO: Zero Redundancy Optimizer
Microsoft's ZeRO (Zero Redundancy Optimizer) family of techniques, integrated into DeepSpeed, partitions optimizer states, gradients, and parameters across the data-parallel group to eliminate redundant storage. In standard data-parallel training, every GPU holds a complete copy of all parameters, gradients, and optimizer states — this is the redundancy ZeRO eliminates.
ZeRO Stage 1 partitions optimizer states across data-parallel ranks, reducing optimizer state memory by the data-parallel degree N: each rank stores and updates only 1/N of the optimizer states, and the updated parameters are then all-gathered. ZeRO Stage 2 additionally partitions gradients: gradients are reduce-scattered so that each rank keeps only those for its own partition, updates that partition, and all-gathers the updated parameters. ZeRO Stage 3 goes further by also partitioning the parameters themselves: each rank holds only 1/N of the parameters and gathers each layer's full parameters on demand for the forward and backward passes. For a 1024-GPU data-parallel configuration, ZeRO Stage 3 reduces the per-GPU memory for parameters, gradients, and optimizer states by approximately 1024x, so models far too large for a single GPU or node can be trained with data parallelism alone, without tensor or pipeline model parallelism.
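Following the accounting in the ZeRO paper (2 bytes per parameter for 16-bit weights, 2 for gradients, and K = 12 for the FP32 master copy and Adam moments), per-GPU model-state memory at each stage can be sketched as below. The function name is illustrative, and activations and communication buffers are deliberately excluded, since ZeRO does not shard activations.

```python
def zero_model_state_gb(params: int, dp: int, stage: int, k: int = 12) -> float:
    """Per-GPU model-state memory (GB) under ZeRO with mixed-precision Adam.
    2 bytes/param weights, 2 bytes/param gradients, k bytes/param optimizer state."""
    w, g, o = 2 * params, 2 * params, k * params
    if stage == 0:    # plain data parallelism: everything replicated
        total = w + g + o
    elif stage == 1:  # optimizer state sharded
        total = w + g + o / dp
    elif stage == 2:  # optimizer state and gradients sharded
        total = w + (g + o) / dp
    elif stage == 3:  # optimizer state, gradients, and parameters sharded
        total = (w + g + o) / dp
    else:
        raise ValueError("stage must be 0, 1, 2, or 3")
    return total / 1e9
```

For a 7.5B-parameter model on 64 data-parallel GPUs this yields 120 GB without ZeRO, about 31.4 GB with Stage 1, 16.6 GB with Stage 2, and 1.9 GB with Stage 3, matching the worked example in the ZeRO paper.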
ZeRO-Infinity extends this to CPU memory and NVMe storage, enabling models with hundreds of billions of parameters or more to be trained on clusters where aggregate HBM would be insufficient. Because parameters, gradients, and optimizer states must stream over PCIe and NVMe every step, ZeRO-Infinity needs enough computation per step to hide those transfers and otherwise runs at lower hardware utilization, but it makes otherwise impossible training configurations feasible. ZeRO-Offload specifically offloads optimizer states and gradients to the CPU and runs the optimizer update there, which suits configurations where CPU DRAM is abundant and GPU computation dominates step time.
KV Cache Management in Inference
In autoregressive LLM inference, the KV cache stores the key and value tensors of all previous tokens so they need not be recomputed at each decoding step. For a batch of B sequences of length T on a model with L layers, H key/value heads, and head dimension D, the KV cache size is 2 × B × L × H × D × T × bytes_per_element, where the factor of 2 covers keys and values. With grouped-query attention, H counts key/value heads, which can be far fewer than query heads. For Llama-2 70B (80 layers, 8 key/value heads, head dimension 128) serving a batch of 128 sequences at 2048 tokens in BF16, the KV cache alone requires about 80 GiB (86 GB), more than half the 140 GB occupied by the BF16 weights.
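The KV cache formula translates directly into code once it includes a batch dimension and counts key/value heads. The sketch assumes Llama-2 70B's published configuration (80 layers, head dimension 128, grouped-query attention with 8 key/value heads versus 64 query heads) and a BF16 cache.

```python
def kv_cache_bytes(batch: int, layers: int, kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_elem: int = 2) -> int:
    """2 (keys and values) x B x L x H_kv x D x T x bytes_per_element."""
    return 2 * batch * layers * kv_heads * head_dim * seq_len * bytes_per_elem

if __name__ == "__main__":
    n = kv_cache_bytes(batch=128, layers=80, kv_heads=8, head_dim=128, seq_len=2048)
    print(f"{n:,} bytes = {n / 2**30:.1f} GiB = {n / 1e9:.1f} GB")
```

The result is 85,899,345,920 bytes, exactly 80 GiB. Caching all 64 heads without grouped-query attention would need eight times as much.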
PagedAttention, introduced with vLLM, addresses KV cache fragmentation by managing cache memory in fixed-size blocks, analogous to pages in operating-system virtual memory. Rather than pre-allocating contiguous memory for each sequence's maximum possible length, PagedAttention allocates blocks on demand and maintains a block table mapping logical token positions to physical memory blocks. This lets the same physical memory serve more concurrent sequences; the vLLM paper reports 2-4x higher throughput than prior serving systems at comparable latency.
Continuous batching (implemented in vLLM, and in TensorRT-LLM as in-flight batching) further improves efficiency by adding and removing sequences from the active batch at each decoding iteration as they complete, rather than waiting for every sequence in a batch to finish before starting the next batch.
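The block-table idea behind PagedAttention can be sketched as a small allocator. This is a toy under stated assumptions: it only tracks which physical block holds which token positions, not the key/value data, and `PagedKVAllocator` with its methods are hypothetical names, not vLLM's API.

```python
from typing import Dict, List, Tuple

class PagedKVAllocator:
    """Toy block allocator in the spirit of PagedAttention: each sequence owns a
    block table mapping logical block index -> physical block id."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size                # tokens per block
        self.free: List[int] = list(range(num_blocks))
        self.tables: Dict[int, List[int]] = {}      # seq id -> physical block ids
        self.lengths: Dict[int, int] = {}           # seq id -> tokens stored

    def append_token(self, seq: int) -> None:
        """Reserve space for one more token, allocating a block only when needed."""
        n = self.lengths.get(seq, 0)
        if n % self.block_size == 0:                # no block yet, or last block full
            if not self.free:
                raise MemoryError("KV cache exhausted; a scheduler would preempt here")
            self.tables.setdefault(seq, []).append(self.free.pop())
        self.lengths[seq] = n + 1

    def physical_slot(self, seq: int, pos: int) -> Tuple[int, int]:
        """Translate a logical token position into (physical block, offset)."""
        return self.tables[seq][pos // self.block_size], pos % self.block_size

    def release(self, seq: int) -> None:
        """Return a finished sequence's blocks to the free pool."""
        self.free.extend(self.tables.pop(seq, []))
        self.lengths.pop(seq, None)

if __name__ == "__main__":
    alloc = PagedKVAllocator(num_blocks=8, block_size=16)
    for _ in range(20):
        alloc.append_token(seq=0)   # 20 tokens -> 2 blocks
    print(alloc.tables, alloc.physical_slot(0, 17))
```

With 16-token blocks, each sequence wastes at most one partially filled block instead of reserving its maximum length up front, and a finished sequence's blocks return to the pool immediately, which is what lets continuous batching admit new requests mid-flight.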
Key Takeaways
- GPU memory is consumed by five categories: parameters, optimizer states, gradients, activations, and temporary workspace — each with different optimization strategies.
- Activation checkpointing trades ~33% additional compute for substantial activation memory savings, often making the difference between a feasible and infeasible training configuration.
- BF16 mixed-precision training halves activation memory and the working copy of the weights relative to FP32 while keeping an FP32 master copy and FP32 optimizer states; FP8 pushes further at the cost of more complex per-tensor scaling.
- ZeRO Stage 3 partitions parameters, gradients, and optimizer states across the data-parallel group, dividing per-GPU model-state memory by the group size (activations are not sharded) and enabling very large models on clusters with limited per-GPU HBM.
- KV cache management is the dominant memory challenge in LLM inference; PagedAttention and continuous batching dramatically improve GPU utilization for serving workloads.
- Flash Attention's SRAM-tiled computation avoids materializing the O(T²) attention matrix in HBM (T being the sequence length), making long-context training and inference practical.
Conclusion
GPU memory management is not a single technique but a multi-layered discipline spanning hardware architecture, numerical precision, algorithmic restructuring, and system-level partitioning. The combination of activation checkpointing, mixed-precision training, ZeRO optimizer sharding, and Flash Attention has enabled the training of models that would have seemed impossible just a few years ago. As models continue to grow and context lengths extend to hundreds of thousands of tokens, memory management will remain one of the most active areas of ML systems research and engineering. Building a deep understanding of each layer of the memory hierarchy and each technique's tradeoffs is essential preparation for designing efficient AI compute systems.