GPU Memory Math Viz

Live, workload-agnostic breakdown of the memory required to train a model. A log-scale slider drives the parameter count; a ZeRO radiogroup demonstrates how data-parallel sharding across N GPUs redistributes the load. Dashed GPU capacity lines and a hatched overflow overlay reveal the OOM cliff; a second per-GPU bar reveals itself once ZeRO is active.

Demo snapshot: 3.7B parameters, total training memory 62 GiB.

GPU capacity lines: RTX 4090 (24 GiB), A100 80GB (80 GiB), H100 80GB (80 GiB). The 62 GiB total exceeds the 24 GiB line, so the demo shows a "DOESN'T FIT" badge.

Memory formula
Weights = 3.7B params x 2 bytes = 6.8 GiB
Master = 3.7B params x 4 bytes = 14 GiB
Momentum = 3.7B params x 4 bytes = 14 GiB
Variance = 3.7B params x 4 bytes = 14 GiB
Gradients = 3.7B params x 2 bytes = 6.8 GiB
Activations = 3.7B params x 2 bytes = 6.8 GiB
Total = 3.7B x 18 bytes = 62 GiB
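
The memory formula can be reproduced in a few lines. This is a minimal sketch, not the component's source: the names MemorySegment, ADAM_SEGMENTS, segmentGiB and totalGiB are hypothetical, and only the bytes-per-parameter values are taken from the demo's Adam table.

```typescript
// Hypothetical shape for illustration; the real segment type has more fields.
interface MemorySegment {
  id: string;
  bytesPerParam: number;
}

// Bytes per parameter as listed in the demo's Adam table.
const ADAM_SEGMENTS: readonly MemorySegment[] = [
  { id: "weights", bytesPerParam: 2 },
  { id: "master", bytesPerParam: 4 },
  { id: "momentum", bytesPerParam: 4 },
  { id: "variance", bytesPerParam: 4 },
  { id: "gradients", bytesPerParam: 2 },
  { id: "activations", bytesPerParam: 2 },
];

const BYTES_PER_GIB = 2 ** 30;

// Memory of one segment: params x bytesPerParam, converted to GiB.
function segmentGiB(params: number, segment: MemorySegment): number {
  return (params * segment.bytesPerParam) / BYTES_PER_GIB;
}

// Total memory: the sum over all segments.
function totalGiB(params: number, segments: readonly MemorySegment[]): number {
  return segments.reduce((sum, s) => sum + segmentGiB(params, s), 0);
}

// 3.7e9 params x 18 bytes, expressed in GiB and rounded for display.
console.log(Math.round(totalGiB(3.7e9, ADAM_SEGMENTS)));
```

The per-segment byte counts sum to 18 bytes per parameter, which is where the "x 18 bytes" in the total comes from.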
Keyboard: Left / Right adjust the slider, 1-4 select presets, Z cycles the ZeRO stage.

Installation

npx shadcn@latest add https://craftbits.dev/r/gpu-memory-math-viz.json

Usage

import { GpuMemoryMathViz } from "@craft-bits/viz/gpu-memory-math-viz";
 
<GpuMemoryMathViz defaultParams={7e9} defaultZeroStage="none" />

Drive the parameter count + ZeRO stage from outside (controlled mode):

import { useState } from "react";

const [params, setParams] = useState(7e9);
const [zeroStage, setZeroStage] = useState<"none" | "stage1" | "stage2" | "stage3">("none");
 
<GpuMemoryMathViz
  params={params}
  onParamsChange={setParams}
  zeroStage={zeroStage}
  onZeroStageChange={setZeroStage}
/>;

Swap in a custom memory table (e.g. quantised inference, not Adam training):

<GpuMemoryMathViz
  segments={[
    { id: "weights", label: "INT4 Weights", shortLabel: "Weights", bytesPerParam: 0.5, shardFrom: "stage3" },
    { id: "kv", label: "FP16 KV Cache", shortLabel: "KV", bytesPerParam: 1.5, shardFrom: "stage3" },
  ]}
  presets={[
    { label: "7B", params: 7e9, key: "1" },
    { label: "70B", params: 70e9, key: "2" },
  ]}
/>

Understanding the component

  1. Log-scale slider. Parameters span ~3 orders of magnitude (100M → 70B), so the slider operates on log10(params). The pointer position maps to log space and is converted back to params on read, so equal slider distances correspond to equal multiplicative changes in model size.
  2. Stacked memory bar. Each segment contributes params x bytesPerParam bytes. Segments are stacked horizontally, normalised against max(totalGiB, maxGpuRef) x 1.15 so the bar always shows context (the next GPU ceiling) even when the workload comfortably fits.
  3. GPU capacity lines. Every gpuRefs[] entry paints a dashed line at its capacity in the bar's coordinate space and labels itself below. When the total crosses a line, a "DOESN'T FIT" badge appears above and the overflow region of every segment gets a hatched overlay.
  4. ZeRO stages. Each segment declares its shardFrom stage. At a given ZeRO stage, sharded segments are divided across gpuCount GPUs; non-sharded segments remain full. The component reveals a second bar showing the per-GPU breakdown, plus chips that label each segment as name / N (sharded) or name (full).
  5. Controlled / uncontrolled. params + onParamsChange and zeroStage + onZeroStageChange follow the Radix pattern — pass both for controlled, omit both for self-managed.
  6. Reduced motion. Under prefers-reduced-motion: reduce, every bar transition collapses to duration: 0. The ZeRO detail still appears but enters without movement.
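
Steps 1, 2 and 4 above come down to a few lines of arithmetic. The sketch below is illustrative rather than the component's actual code: every function and type name is hypothetical, the slider position is assumed to be normalised to [0, 1], and an omitted shardFrom is assumed to mean "never sharded" (the real segment type may express this differently).

```typescript
type ZeroStage = "none" | "stage1" | "stage2" | "stage3";

// Hypothetical segment shape; assumption: an omitted shardFrom means never sharded.
interface ShardableSegment {
  id: string;
  bytesPerParam: number;
  shardFrom?: Exclude<ZeroStage, "none">;
}

const STAGE_ORDER: readonly ZeroStage[] = ["none", "stage1", "stage2", "stage3"];

// Step 1: slider position t in [0, 1] to params, linear in log10 space.
function sliderToParams(t: number, minParams: number, maxParams: number): number {
  const lo = Math.log10(minParams);
  const hi = Math.log10(maxParams);
  return 10 ** (lo + t * (hi - lo));
}

// Inverse mapping: params back to a slider position in [0, 1].
function paramsToSlider(params: number, minParams: number, maxParams: number): number {
  const lo = Math.log10(minParams);
  const hi = Math.log10(maxParams);
  return (Math.log10(params) - lo) / (hi - lo);
}

// Step 2: the bar's full width in GiB, with 15% headroom.
function barScaleGiB(totalGiB: number, maxGpuRefGiB: number): number {
  return Math.max(totalGiB, maxGpuRefGiB) * 1.15;
}

// Step 4: a segment is sharded once the active stage reaches its shardFrom stage.
function isSharded(segment: ShardableSegment, stage: ZeroStage): boolean {
  if (segment.shardFrom === undefined) return false;
  return STAGE_ORDER.indexOf(stage) >= STAGE_ORDER.indexOf(segment.shardFrom);
}

// Per-GPU bytes: sharded segments are divided by gpuCount, the rest stay full.
function perGpuBytes(
  params: number,
  segments: readonly ShardableSegment[],
  stage: ZeroStage,
  gpuCount: number,
): number {
  return segments.reduce((sum, s) => {
    const bytes = params * s.bytesPerParam;
    return sum + (isSharded(s, stage) ? bytes / gpuCount : bytes);
  }, 0);
}
```

In the standard ZeRO definition, stage 1 shards the optimizer states, stage 2 adds gradients and stage 3 adds the weights. Whether the component's default Adam table assigns shardFrom exactly this way is an assumption here.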

Props

| Prop | Type | Default | Description |
| --- | --- | --- | --- |
| params | number | - | Controlled parameter count. Pair with onParamsChange. |
| defaultParams | number | 7e9 | Uncontrolled initial parameter count. |
| onParamsChange | (next: number) => void | - | Fires when the slider or a preset chip changes the count. |
| zeroStage | "none" \| "stage1" \| "stage2" \| "stage3" | - | Controlled ZeRO stage. |
| defaultZeroStage | "none" \| "stage1" \| "stage2" \| "stage3" | "none" | Uncontrolled initial ZeRO stage. |
| onZeroStageChange | (next) => void | - | Fires when the ZeRO stage changes. |
| minParams | number | 100e6 | Lower bound of the log-scale slider. |
| maxParams | number | 70e9 | Upper bound of the log-scale slider. |
| gpuCount | number | 8 | Number of GPUs assumed for ZeRO sharding. |
| segments | readonly GpuMemoryMathVizSegment[] | Adam table | Memory cost segments (label, bytes/param, shardFrom). |
| gpuRefs | readonly GpuMemoryMathVizGpuRef[] | A100 / H100 / 4090 | GPU capacity reference lines. |
| presets | readonly GpuMemoryMathVizPreset[] | GPT-2 / 7B / 13B / 70B | Quick-preset chips. |
| transition | Transition | SPRINGS.smooth | Override the bar-segment growth spring. |
| className | string | - | Merged onto the root via cn(). |

Accessibility

  • The root is role="figure" with an aria-labelledby summary that names the parameter count, the total training memory, and (when ZeRO is active) the per-GPU memory.
  • The summary is mirrored in a live region (aria-live="polite") so screen-reader users hear the same update as sighted users when the slider moves.
  • The parameter slider is a native <input type="range"> with aria-valuemin / aria-valuemax / aria-valuenow / aria-valuetext mirroring its current state and a visible label.
  • The ZeRO toggle is a proper role="radiogroup" with aria-checked on each role="radio" button.
  • Keyboard support: Left / Right move the slider one log-step at a time, number keys 1 to n jump to the preset chips, and Z cycles ZeRO stages.
  • Focus styling on every control uses :focus-visible with a 2px ring offset from the surface so it remains visible against both light and dark themes.
  • Colour is never the only signal — the per-segment GiB readouts, the "DOESN'T FIT" badges, and the per-GPU chips are all text.
  • Motion respects prefers-reduced-motion: reduce — bar transitions snap and the ZeRO detail enters without movement.
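
The keyboard behaviour described above can be modelled as a pure state transition, which keeps it easy to test without a DOM. This is a sketch under stated assumptions, not the component's handler: handleKey, KeyState and KeyConfig are hypothetical names, the size of one log-step is a made-up parameter, and presets are matched on the key field shown in the usage example.

```typescript
type StageName = "none" | "stage1" | "stage2" | "stage3";

const STAGES: readonly StageName[] = ["none", "stage1", "stage2", "stage3"];

// Hypothetical state and config shapes for illustration.
interface KeyState {
  params: number;
  zeroStage: StageName;
}

interface KeyConfig {
  minParams: number;
  maxParams: number;
  logStep: number; // assumption: step size in log10 units per key press
  presets: readonly { params: number; key: string }[];
}

function handleKey(state: KeyState, key: string, cfg: KeyConfig): KeyState {
  // Left / Right: move one step in log10 space, clamped to the slider bounds.
  if (key === "ArrowLeft" || key === "ArrowRight") {
    const direction = key === "ArrowRight" ? 1 : -1;
    const next = 10 ** (Math.log10(state.params) + direction * cfg.logStep);
    const clamped = Math.min(cfg.maxParams, Math.max(cfg.minParams, next));
    return { ...state, params: clamped };
  }
  // Z: advance to the next ZeRO stage, wrapping from stage3 back to none.
  if (key.toLowerCase() === "z") {
    const index = STAGES.indexOf(state.zeroStage);
    return { ...state, zeroStage: STAGES[(index + 1) % STAGES.length] };
  }
  // Any other key: jump to the preset bound to it, if there is one.
  const preset = cfg.presets.find((p) => p.key === key);
  return preset ? { ...state, params: preset.params } : state;
}
```

In a React component, a function like this would typically be called from an onKeyDown handler with event.key, and its result passed to the params and ZeRO stage setters.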

Credits

  • Extracted from: craftingattention (app/src/lessons/primitives/systems/GPUMemoryMathViz.tsx). The source baked in the Adam optimizer memory table, two hardcoded GPU configs, four hardcoded LLM presets, a lesson-specific narration paragraph, and raw var(--color-*) lesson tokens. The viz extract drops the narration (lesson chrome) and generalises every list (segments, gpuRefs, presets, gpuCount, slider bounds) to a prop with a sensible default. It remaps the palette to var(--cb-*) semantic tokens and swaps the inline SPRINGS.snappy / .gentle and STAGGER.tight references for the canonical SPRINGS.snap / .smooth and the STAGGER constant from @craft-bits/core/motion. It also surfaces params and zeroStage via the controlled/uncontrolled Radix pattern and replaces the ad-hoc useReducedMotion() with the library's usePrefersReducedMotion hook.