Tensor Parallelism Viz

A diagram of one forward pass through a weight matrix that has been split column-wise across multiple GPUs, paired with a latency-comparison bar that plots compute vs all-reduce time for every configured GPU count. The input vector X feeds every shard at once; each GPU multiplies X by its column slice in parallel; an all-reduce sums the partial results; the final output vector Y appears on the right. With one GPU the all-reduce phase is skipped.
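The arithmetic the diagram animates can be sketched in a few lines. This is an illustration, not code from the component: it assumes the convention Y = W·X, so that a column block of W pairs with the matching slice of X and every shard produces a full-length partial output. The names `shardMatVec` and `tensorParallelMatVec` are hypothetical.

```typescript
// Sketch only: assumes Y = W·X with W stored row-major as number[][].
type Matrix = number[][];

// Multiply the column block [start, end) of W by the matching slice of x.
// The result is a full-length partial output vector.
function shardMatVec(W: Matrix, x: number[], start: number, end: number): number[] {
  return W.map((row) => {
    let acc = 0;
    for (let c = start; c < end; c++) acc += row[c] * x[c];
    return acc;
  });
}

// Split the columns across `gpus` shards, compute each partial output,
// then sum element-wise (a stand-in for the all-reduce).
function tensorParallelMatVec(W: Matrix, x: number[], gpus: number): number[] {
  const cols = x.length;
  const width = Math.ceil(cols / gpus);
  const partials: number[][] = [];
  for (let g = 0; g < gpus; g++) {
    const start = Math.min(g * width, cols);
    const end = Math.min(start + width, cols);
    partials.push(shardMatVec(W, x, start, end));
  }
  return partials.reduce((sum, p) => sum.map((v, i) => v + p[i]));
}
```

With a single shard the sum has only one term, which mirrors the skipped all-reduce phase on one GPU.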

[Interactive demo: 8×8 weight matrix W on 1 GPU, idle and ready to compute. Latency comparison (compute + all-reduce): 1 GPU 100 ms, 2 GPUs 55 ms, 4 GPUs 30 ms.]

Installation

npx shadcn@latest add https://craftbits.dev/r/tensor-parallelism-viz.json

Usage

import { TensorParallelismViz } from "@craft-bits/viz/tensor-parallelism-viz";
 
<TensorParallelismViz />

Start on four GPUs and run the forward pass automatically:

<TensorParallelismViz defaultGpus={4} defaultPlaying />

Drive the visual from outside (controlled phase + active GPU):

<TensorParallelismViz
  gpus={4}
  phase="compute"
  activeGpu={2}
/>

Provide a custom latency table (e.g. for an 8-GPU configuration):

<TensorParallelismViz
  gpuOptions={[1, 2, 4, 8]}
  timings={{
    1: { compute: 200, comm: 0, total: 200 },
    2: { compute: 100, comm: 8, total: 108 },
    4: { compute: 50, comm: 10, total: 60 },
    8: { compute: 25, comm: 12, total: 37 },
  }}
  defaultGpus={8}
/>
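A table like the one above can also be generated rather than typed by hand. The sketch below assumes compute time scales as 1/N from a single-GPU baseline, which matches the numbers in the example but is not something the library prescribes. `buildTimings` is a hypothetical helper, and the `Timing` shape is inferred from the literal above rather than taken from the package's exported types.

```typescript
// Assumed shape, inferred from the timings literal in the example.
interface Timing {
  compute: number;
  comm: number;
  total: number;
}

// Hypothetical helper: derive each row from a single-GPU compute time
// and a per-GPU-count communication cost, assuming ideal 1/N scaling.
function buildTimings(
  baselineComputeMs: number,
  commMs: Record<number, number>,
): Record<number, Timing> {
  const table: Record<number, Timing> = {};
  for (const [key, comm] of Object.entries(commMs)) {
    const gpus = Number(key);
    const compute = baselineComputeMs / gpus;
    table[gpus] = { compute, comm, total: compute + comm };
  }
  return table;
}

// Reproduces the four rows of the 8-GPU example.
const generatedTimings = buildTimings(200, { 1: 0, 2: 8, 4: 10, 8: 12 });
```

The result can be passed as the timings prop alongside a matching gpuOptions array.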

Understanding the component

  1. Diagram layout. The SVG places the input vector X on the left, the sharded weight matrix W in the centre, an all-reduce node next to it, and the output vector Y on the right. Shards are offset horizontally by four pixels each so they read as separate physical devices.
  2. Lifecycle. Four phases — idle, compute, allreduce, done — gate every visual element. Arrows draw in on compute; the all-reduce block appears on allreduce and tints success once done; the output Y cells fade in only on done.
  3. Autoplay loop. When playing is true and reduced motion is off, a setInterval advances the phase every playSpeed milliseconds. The compute phase walks activeGpu from 0 to N − 1 before flipping to allreduce (or directly to done when N = 1).
  4. Controlled / uncontrolled. gpus, phase, activeGpu, and playing all follow Radix's pattern — pass the controlled prop with a handler, or use the default* counterpart. The component never double-tracks state.
  5. Latency-comparison bar. Each option in gpuOptions gets a horizontal bar whose compute (accent) and comm (warning) segments are scaled against the baseline total, i.e. the total for gpuOptions[0]. The current row glows accent once the forward pass reaches allreduce / done.
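The autoplay walk (item 3) and the bar scaling (item 5) can be expressed as two pure functions. This is a sketch of the behaviour described above, not the component's source: `nextStep` and `barSegments` are hypothetical names, and the choices to start compute at GPU 0, to treat an activeGpu of -1 as "compute finished", and to stay on done are assumptions.

```typescript
type Phase = "idle" | "compute" | "allreduce" | "done";

interface Step {
  phase: Phase;
  activeGpu: number; // -1 means no single shard is active
}

interface Timing {
  compute: number;
  comm: number;
  total: number;
}

// One tick of the autoplay loop for N = gpus.
function nextStep(step: Step, gpus: number): Step {
  switch (step.phase) {
    case "idle":
      return { phase: "compute", activeGpu: 0 }; // assumed starting point
    case "compute":
      // Walk activeGpu from 0 to N - 1, then leave the compute phase.
      if (step.activeGpu >= 0 && step.activeGpu < gpus - 1) {
        return { phase: "compute", activeGpu: step.activeGpu + 1 };
      }
      // With one GPU there is nothing to reduce, so skip to done.
      return { phase: gpus === 1 ? "done" : "allreduce", activeGpu: -1 };
    case "allreduce":
      return { phase: "done", activeGpu: -1 };
    case "done":
      return step; // assumed: autoplay stops rather than looping
  }
}

// Segment widths in percent, relative to the baseline (gpuOptions[0]) total.
function barSegments(
  timings: Record<number, Timing>,
  gpuOptions: readonly number[],
  gpus: number,
): { computePct: number; commPct: number } {
  const baseline = timings[gpuOptions[0]].total;
  const row = timings[gpus];
  return {
    computePct: (row.compute * 100) / baseline,
    commPct: (row.comm * 100) / baseline,
  };
}
```

Keeping both as pure functions means the interval callback only has to call `nextStep` and store the result, which is one way to avoid tracking the phase in two places.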

Props

| Prop | Type | Default | Description |
| --- | --- | --- | --- |
| gpuOptions | readonly number[] | [1, 2, 4] | GPU-count options surfaced as buttons. |
| timings | Record<number, TensorParallelismVizTiming> | TENSOR_PARALLELISM_VIZ_DEFAULT_TIMINGS | Per-GPU-count latency table. |
| gpus / defaultGpus | number | 1 | Controlled / uncontrolled GPU count. |
| phase / defaultPhase | TensorParallelismVizPhase | "idle" | Controlled / uncontrolled phase. |
| activeGpu / defaultActiveGpu | number | -1 | Active GPU index during compute. -1 highlights every shard. |
| playing / defaultPlaying | boolean | false | Controlled / uncontrolled autoplay state. |
| playSpeed | number | 420 | Milliseconds between phase advances. Floored at 80 ms. |
| rows / cols | number | 8 / 8 | Weight-matrix dimensions. |
| showGpuLabels | boolean | true | Render the "GPU 1, GPU 2, …" labels under shards. |
| showTimingBar | boolean | true | Render the latency-comparison bar. |
| transition | Transition | SPRINGS.smooth | Override the spring used for cell / label transitions. |

Accessibility

  • The root element has role="figure" and an aria-labelledby that points at a hidden summary, so screen-reader users get a one-line description before exploring the diagram.
  • A polite live region announces phase changes — e.g. "GPU 3 multiplying its shard.", "All-reduce. Summing partial outputs across 4 GPUs.", "Done. 30 ms total."
  • The GPU buttons are real <button> elements with aria-pressed and a descriptive aria-label ("4 GPUs"). The forward-pass button mirrors its visible label through aria-label and is disabled when not idle.
  • Colour is never the only signal — the active shard gets a thicker stroke in addition to higher fill opacity; the all-reduce block changes shape; the phase chips at the bottom highlight the current phase with both colour and a filled dot.
  • Motion respects prefers-reduced-motion: reduce — every cell / arrow / all-reduce / output transition collapses to instant. Autoplay is a no-op when reduced motion is on.
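The live-region messages quoted above follow directly from the phase, so they can be produced by one small function. The sketch below reuses the quoted strings and the idle text shown in the demo. The function name `announce`, the mapping from a 0-based activeGpu to a 1-based label, and the wording for the "every shard" case (activeGpu of -1) are assumptions.

```typescript
type Phase = "idle" | "compute" | "allreduce" | "done";

// Hypothetical helper: build the text for the polite live region.
function announce(
  phase: Phase,
  activeGpu: number, // 0-based index, -1 when every shard is highlighted
  gpus: number,
  totalMs: number,
): string {
  switch (phase) {
    case "idle":
      return "Idle. Ready to compute.";
    case "compute":
      return activeGpu < 0
        ? "All GPUs multiplying their shards." // assumed wording
        : `GPU ${activeGpu + 1} multiplying its shard.`;
    case "allreduce":
      return `All-reduce. Summing partial outputs across ${gpus} GPUs.`;
    case "done":
      return `Done. ${totalMs} ms total.`;
  }
}
```

Rendering the returned string inside an element with aria-live="polite" lets assistive technology read each phase change without interrupting the user.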

Credits

  • Extracted from: craftingattention (app/src/lessons/primitives/viz/TensorParallelismViz.tsx). The source wrapped the diagram in a Widget with three modes (Explore / Predict / Challenge), useWidgetHistory undo / redo, bookmarks, a narration band keyed to a fixed 100 / 50 / 25 ms latency table, and a per-shard colour palette. The library extract strips the Widget chrome and lesson modes entirely and drops the multi-hue per-GPU palette in favour of a single --cb-accent. It keeps the GPU selector, run button, diagram, and latency-comparison bar as the four primitive pieces, lifts the latency table to a timings prop, and exposes the full Radix-style controlled / uncontrolled API for gpus / phase / activeGpu / playing. Colours are remapped to var(--cb-accent) / var(--cb-success) / var(--cb-warning) / var(--cb-fg-*) / var(--cb-bg-*), the inline spring is replaced by SPRINGS.smooth from @craft-bits/core/motion, and the per-arrow markers are scoped by useId to avoid SVG marker-id collisions when multiple instances render on the same page.