Optimizing NVIDIA FP4 Kernels on B200
How Kai Optimized the Only Path to 4-Bit GPU Acceleration
Multi-phase optimization • NVIDIA B200 Blackwell • 153.8x speedup achieved
The Problem: 4-Bit Promise, Orchestration Reality
NVIDIA's FP4 format offers 4x memory bandwidth improvement over FP16 for inference workloads. The GPU Mode NVFP4 GEMM competition challenges participants to implement a block-scaled FP4 matrix-multiply kernel optimized for NVIDIA B200. The only available FP4 compute route runs through CUTLASS block-scaled kernels, exposed via torch._scaled_mm.
The competition's reference kernel calls torch._scaled_mm in a loop, converting scale factors on the CPU each iteration before transferring them to the GPU. The result: a painful 24,888.8µs.
We pointed Kai at this competition to see what an iterative optimization approach could find.
Optimization Setup
Kai was configured with MAP-Elites island-based optimization and a multi-model LLM ensemble:
```yaml
max_iterations: 500          # per optimization phase
database:
  num_islands: 4
  population_size: 100
  archive_size: 40
  migration_interval: 20
  feature_dimensions:
    - "geo_speedup"
    - "mean_speed_of_light_ratio"
    - "max_abs_error"
llm:
  models:
    - name: "z-ai/glm-4.6"                # weight: 0.6
    - name: "anthropic/claude-opus-4.5"   # weight: 0.2
    - name: "google/gemini-3-pro-preview" # weight: 0.1
    - name: "anthropic/claude-sonnet-4.5" # weight: 0.1
diff_based_evolution: true
```

Kai ran across multiple phases, each up to 500 iterations, with the best kernel from each phase seeding the next. Evaluation happened on NVIDIA B200 GPUs via Modal cloud deployment, using the competition's three benchmark shapes:
```python
BENCHMARKS = [
    ((7168, 16384, 1), 11, 8.622),   # Large single GEMV
    ((4096, 7168, 8), 29, 17.275),   # Medium batch-8
    ((7168, 2048, 4), 47, 4.317),    # Wide batch-4
]
```

Phase Progression: What Kai Discovered
The experiment ran through multiple phases, each exploring different optimization strategies. The performance progression, drawn from the actual phase file docstrings:
| Phase | Strategy | Speedup | Latency | Key Change |
|---|---|---|---|---|
| 6 | Forced Triton kernel | — | incorrect | Attempted raw Triton FP4 handling; most mutations crashed |
| 4→8 | Open-ended optimization | 145.9x→157.9x | 156µs→138µs | GPU-resident scale factors, unrolled assignment |
| 8→9 | PyTorch refinement | 157.9x→161.2x | 138µs→131µs | Inlined scale conversion, torch.stack batching |
| 9→10 | Extreme PyTorch | 161.2x+ | 131µs+ | List comprehension approach, .select(1,0) |
The Dead End: Raw Triton (Phase 6)
Failed approach. The first attempt was a raw Triton kernel that handled FP4 directly:
```python
@triton.jit
def _triton_gemv_kernel(
    a_ptr, b_ptr, c_ptr, sfa_ptr, sfb_ptr,
    M, K, ...,
    BLOCK_M: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # Naive: cast FP4 packed uint8 to float32 and accumulate
    a_tile = tl.load(a_ptrs, mask=a_mask, other=0.0).to(tl.float32)
    b_vals = tl.load(b_ptrs, mask=mask_k, other=0.0).to(tl.float32)
    acc += tl.sum(a_tile * sfa_val * b_vals[None, :] * sfb_val, axis=1)
```

This didn't work: Triton couldn't correctly interpret float4_e2m1fn_x2 packed tensors, and most mutations produced incorrect results. Kai abandoned this direction and pivoted.
The Breakthrough: GPU-Resident Scale Factors (Phase 4→8)
The pivotal insight: the bottleneck wasn't the CUTLASS kernel — it was the orchestration. The naive reference was calling to_blocked() on the CPU for every scale factor slice, then transferring to GPU:
```python
# Naive: CPU-side conversion hell
def ref_kernel(data):
    a_ref, b_ref, sfa_ref_cpu, sfb_ref_cpu, _, _, c_ref = data
    _, _, l = c_ref.shape
    for l_idx in range(l):
        # CPU conversion every iteration!
        scale_a = to_blocked(sfa_ref_cpu[:, :, l_idx]).to(device=a_ref.device)
        scale_b = to_blocked(sfb_ref_cpu[:, :, l_idx]).to(device=a_ref.device)
        res = torch._scaled_mm(...)
        c_ref[:, 0, l_idx] = res[:, 0]
    return c_ref


# Evolved: GPU-native scale factors
def custom_kernel(data):
    a, b, sfa_host, sfb_host, sfa_permuted, sfb_permuted, c_out = data
    del sfa_host, sfb_host  # Don't even touch CPU versions
    sfa_blocked = sfa_permuted.permute(5, 2, 4, 0, 1, 3).reshape(l, -1)
    sfb_blocked = sfb_permuted.permute(5, 2, 4, 0, 1, 3).reshape(l, -1)
    b_t = b.permute(2, 1, 0)
```

Result: 24,888.8µs → 156µs. The entire 100x+ gain came from eliminating data movement, not changing the math.
The Refinement: L-Specific Unrolling (Phase 8→9)
With the core bottleneck solved, Kai turned to micro-optimization. It discovered that different L values (the batch dimension in the competition shapes: L=1, L=4, L=8) benefited from fundamentally different strategies.
The L=1 case avoids batch scale factor conversion entirely:
```python
if l == 1:
    scale_a = sfa_permuted.select(5, 0).permute(2, 4, 0, 1, 3).reshape(-1)
    scale_b = sfb_permuted.select(5, 0).permute(2, 4, 0, 1, 3).reshape(-1)
    gemv = torch._scaled_mm(a.select(2, 0), b.select(2, 0).t(), scale_a, scale_b,
                            bias=None, out_dtype=torch.float16)
    return gemv[:, :1].view(m, 1, 1)
```

For L=2 and L=4, Kai chose torch.stack to batch the results:
```python
if l == 4:
    g0 = torch._scaled_mm(a.select(2, 0), b_t[0], sfa_blocked[0], sfb_blocked[0],
                          bias=None, out_dtype=torch.float16)[:, 0]
    g1 = torch._scaled_mm(a.select(2, 1), b_t[1], sfa_blocked[1], sfb_blocked[1],
                          bias=None, out_dtype=torch.float16)[:, 0]
    g2 = torch._scaled_mm(a.select(2, 2), b_t[2], sfa_blocked[2], sfb_blocked[2],
                          bias=None, out_dtype=torch.float16)[:, 0]
    g3 = torch._scaled_mm(a.select(2, 3), b_t[3], sfa_blocked[3], sfb_blocked[3],
                          bias=None, out_dtype=torch.float16)[:, 0]
    return torch.stack([g0, g1, g2, g3], dim=1).unsqueeze(1)
```

For L=8, it pre-selected all A slices into named variables to avoid repeated indexing:
```python
if l == 8:
    a0, a1, a2, a3 = a.select(2, 0), a.select(2, 1), a.select(2, 2), a.select(2, 3)
    a4, a5, a6, a7 = a.select(2, 4), a.select(2, 5), a.select(2, 6), a.select(2, 7)
    g0 = torch._scaled_mm(a0, b_t[0], sfa_blocked[0], sfb_blocked[0],
                          bias=None, out_dtype=torch.float16)[:, 0]
    # ... g1 through g7 ...
    return torch.stack([g0, g1, g2, g3, g4, g5, g6, g7], dim=1).unsqueeze(1)
```

The general-case fallback uses 4-wide loop unrolling:
```python
result = torch.empty_like(c_out)
i = 0
while i + 3 < l:
    gemv0 = torch._scaled_mm(a.select(2, i),   b_t[i],   sfa_blocked[i],   sfb_blocked[i],
                             bias=None, out_dtype=torch.float16)
    gemv1 = torch._scaled_mm(a.select(2, i+1), b_t[i+1], sfa_blocked[i+1], sfb_blocked[i+1],
                             bias=None, out_dtype=torch.float16)
    gemv2 = torch._scaled_mm(a.select(2, i+2), b_t[i+2], sfa_blocked[i+2], sfb_blocked[i+2],
                             bias=None, out_dtype=torch.float16)
    gemv3 = torch._scaled_mm(a.select(2, i+3), b_t[i+3], sfa_blocked[i+3], sfb_blocked[i+3],
                             bias=None, out_dtype=torch.float16)
    result[:, 0, i]   = gemv0[:, 0]
    result[:, 0, i+1] = gemv1[:, 0]
    result[:, 0, i+2] = gemv2[:, 0]
    result[:, 0, i+3] = gemv3[:, 0]
    i += 4
```

Result: 156µs → 131µs (Phase 8→9), with a final benchmarked result of 161.9µs geometric mean across all three competition shapes (3 runs, 2 warmup + 5 timing each).
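Stripped of the GEMM calls, the fallback above follows the standard 4-wide unroll shape. A plain-Python schematic (illustrative names only, including the scalar tail loop that the excerpt omits):

```python
def unrolled_apply(f, n):
    """Apply f to indices 0..n-1 four at a time, then finish with a scalar tail."""
    out = [None] * n
    i = 0
    while i + 3 < n:
        # Four independent calls per trip cuts per-iteration dispatch overhead 4x
        out[i], out[i+1], out[i+2], out[i+3] = f(i), f(i+1), f(i+2), f(i+3)
        i += 4
    while i < n:          # tail for n not divisible by 4
        out[i] = f(i)
        i += 1
    return out

print(unrolled_apply(lambda x: x * x, 6))  # → [0, 1, 4, 9, 16, 25]
```

The win here is purely Python-side: fewer loop-condition checks and index computations per launched GEMM, which is exactly the kind of dispatch overhead the unrolled kernel targets.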
The cuBLAS Investigation
To contextualize these results, we built a cuBLAS comparison (cublas_comparison.py). We tested under PyTorch 2.8.0 with cuBLAS 12.8.4.1 on the same B200 hardware. Under this version, cuBLAS does not expose a block-scaled FP4 GEMM path — even basic FP4 tensor operations like .to(torch.float4_e2m1fn_x2) fail with copy_ not implemented. This required unpacking FP4 to FP16 via a lookup table and applying scale factors manually before calling torch.mm:
```python
torch.backends.cuda.matmul.allow_tf32 = False

_FP4_E2M1_LUT = torch.tensor([
     0.0,  0.5,  1.0,  1.5,  2.0,  3.0,  4.0,  6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
], dtype=torch.float32)

def custom_kernel(data):
    a, b, sfa_host, sfb_host, sfa_perm, sfb_perm, c_out = data
    result = c_out.clone()
    device = a.device
    _, _, l = c_out.shape
    for idx in range(l):
        a_fp16 = _unpack_fp4x2_to_fp16(a[:, :, idx], device) * _expand_scales(sfa_host[:, :, idx], device)
        b_fp16 = _unpack_fp4x2_to_fp16(b[:, :, idx], device) * _expand_scales(sfb_host[:, :, idx], device)
        gemm = torch.mm(a_fp16, b_fp16.transpose(0, 1))
        result[:, 0, idx] = gemm[:, 0]
    return result
```

Both kernels were evaluated through the same _service_evaluate pipeline on the same B200 GPU, with the same warmup/timing configuration.
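The _unpack_fp4x2_to_fp16 helper is elided above, but the decode it performs is a straight table lookup on each 4-bit code. A plain-Python sketch of that step (the low-nibble-first packing order is our assumption, not taken from the source):

```python
# FP4 E2M1 values indexed by 4-bit code: bit 3 = sign, bits 2-1 = exponent, bit 0 = mantissa
FP4_E2M1_LUT = [
     0.0,  0.5,  1.0,  1.5,  2.0,  3.0,  4.0,  6.0,
    -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0,
]

def unpack_fp4x2(packed_bytes):
    """Decode bytes that each hold two FP4 codes (low nibble assumed first)."""
    values = []
    for byte in packed_bytes:
        values.append(FP4_E2M1_LUT[byte & 0x0F])  # low nibble
        values.append(FP4_E2M1_LUT[byte >> 4])    # high nibble
    return values

print(unpack_fp4x2([0x21]))  # → [0.5, 1.0]: code 1 then code 2
```

In the real comparison kernel this lookup runs as a vectorized tensor index on the GPU rather than a Python loop; the table itself is the full 16-entry E2M1 value set, which is why FP4 can be expanded exactly into FP16 with no rounding.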
Final Results
What Kai Found Out
CUTLASS via torch._scaled_mm is the only working FP4 path
Under PyTorch 2.8.0 / cuBLAS 12.8.4.1, FP4 operations fail at the basics. NVIDIA announced cuBLAS 12.9 FP4 support, but it is not yet available in PyTorch distributions. Kai found the only path that actually exists.
Orchestration dominates performance
The 153.8x speedup came entirely from eliminating CPU-side scale factor conversion and reducing Python dispatch overhead. The underlying torch._scaled_mm calls are identical.
Triton was a dead end
Despite the config explicitly nudging Kai toward Triton kernels, FP4 packed tensor support wasn’t mature enough. Kai correctly abandoned this direction and found a PyTorch-level solution.
Specialization beats generalization
L-specific unrolled paths with torch.stack batching outperform a single generic loop, even though the difference is “just” Python-level overhead.
The optimized FP4 kernel is competitive with FP16 cuBLAS
161.9µs vs 125.1µs (1.29x gap) while operating on 4x smaller data. In production inference where data is already in FP4 format, the optimized kernel avoids conversion overhead entirely.
Benchmarking Infrastructure
Evaluation ran on NVIDIA B200 via Modal cloud deployment:
```python
APP_NAME = "openevolve-gpu-eval-nvfp4-simple"

image = (
    modal.Image.debian_slim(python_version="3.11")
    .apt_install("zlib1g-dev", "libxml2-dev", "build-essential")
    .pip_install("torch==2.8.0", index_url="https://download.pytorch.org/whl/cu128")
    .pip_install("triton==3.3.0", "numpy>=1.22.0")
)

@app.function(image=image, gpu="B200", timeout=240)
def evaluate_nvfp4_on_b200(program_code: str, warmup_runs: int = 2, timing_runs: int = 3) -> dict:
    return _service_evaluate(program_code, "B200", warmup_runs, timing_runs)
```

Timing uses statistics.median per shape, then math.exp(statistics.fmean(math.log(v) for v in vals)) (geometric mean) across shapes. Final numbers are 3-run averages of this geometric mean.
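That aggregation can be reproduced directly with the stdlib. A minimal sketch with hypothetical timings (the evaluator's actual run counts and values differ):

```python
import math
import statistics

def geo_mean_latency(per_shape_runs):
    """Median over runs per benchmark shape, then geometric mean across shapes,
    mirroring the evaluator's aggregation."""
    medians = [statistics.median(runs) for runs in per_shape_runs]
    return math.exp(statistics.fmean(math.log(v) for v in medians))

# Hypothetical timings (µs) for three shapes, three runs each
print(geo_mean_latency([[100, 110, 90], [200, 190, 210], [50, 55, 45]]))
# → ~100.0 (cube root of 100 · 200 · 50)
```

The geometric mean is the right choice here because it weights relative improvements equally across shapes: halving the 50µs shape moves the score as much as halving the 200µs one, which a plain average would not.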
Competition organized by GPU Mode. NVIDIA B200 evaluation via Modal cloud deployment, PyTorch 2.8.0 + cuBLAS 12.8.4.1. Note: NVIDIA announced native block-scaled FP4 support in cuBLAS 12.9, but this is not yet available in PyTorch distributions as of testing.