Performance¶
The forward model is pure JAX, so every backend (MAP, NUTS, geoVI, …) runs against the same compiled computation graph. That makes “how fast is tengri?” a small set of numbers that travel together.
This page summarises what the existing benchmark suite measures, the
headline numbers from the last full run, and how to reproduce them on
your hardware. The full benchmark suite ships under
`bench/scripts/benchmark_*.py`
and is consolidated behind one entry point — see
Health check & dispatcher below.
Warning
The headline numbers below were measured in April–May 2026. Several have not been re-run after recent forward-model changes and may be stale. Treat them as ballpark, not authoritative; re-run the relevant script (see Reproducing the headline numbers) before quoting in a paper or PR. The bench/reports/ directory carries the date of every measurement event.
Headline numbers (Apple M-series CPU, x64, JAX 0.9, last run May 2026)¶
Forward photometric prediction on SDSS ugriz at z = 0.1, 5 bands, running on a single CPU core:
| Configuration | Exact | Compositional | Hybrid (precomputed) |
|---|---|---|---|
| Stellar only | 23.9 ms | 1.5 ms | 59 µs (408×) |
| + nebular (BakedIn) | 24.7 ms | 1.5 ms | 58 µs (424×) |
| + nebular (Cue emulator) | 61.6 ms | 2.4 ms | 567 µs (109×) |
| + dust IR (THEMIS) | 27.3 ms | 2.0 ms | 158 µs (173×) |
| + radio + X-ray + AGN | 76.4 ms | 4.6 ms | 2.44 ms (31×) |
| Kitchen sink (all emitters) | 76.1 ms | 4.6 ms | 2.45 ms (31×) |
— full table at `bench/reports/2026-05-06_forward_model_speedup.md`
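A caveat on methodology: JAX dispatches asynchronously, so a naive wall clock around a jitted call measures dispatch, not compute. Below is a minimal sketch of a safe timing loop in the spirit of these benchmarks (the `bench` helper is hypothetical, not tengri API; `block_until_ready()` is the standard JAX synchronization call), shown here on a plain stand-in function:

```python
import time

def bench(fn, *args, n=100):
    """Median wall time of fn(*args) over n calls, forcing JAX to finish.

    `block_until_ready()` flushes JAX's async dispatch queue; for plain
    Python return values the getattr fallback makes it a no-op.
    """
    fn(*args)  # warm-up call absorbs any one-time JIT compilation
    times = []
    for _ in range(n):
        t0 = time.perf_counter()
        out = fn(*args)
        getattr(out, "block_until_ready", lambda: None)()
        times.append(time.perf_counter() - t0)
    return sorted(times)[n // 2]

# Stand-in for a jitted forward model:
t = bench(lambda x: x * 2.0, 3.0, n=50)
```

Reporting the median rather than the mean keeps a single GC pause or OS hiccup from skewing the number.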
Inference backends on a 7-parameter mock fit (compile + sample wall):
| Backend | First call | Steady-state |
|---|---|---|
| MAP (L-BFGS) | ~5 s | < 1 s |
| Laplace | ~5 s | < 1 s |
| Pathfinder | ~10 s | ~2 s |
| NUTS (1k samples) | ~30 s | ~5 s |
| MGVI | ~10 s | 2.3 s |
| geoVI | ~75 s | 43.7 s |
— full breakdowns: `2026-04-17_native_vs_nifty.md`, `2026-04-22_pathfinder_vs_window_nuts.md`, `2026-05-06_compile_vs_sampling_breakdown.md`
`vi_native` is 19–25× faster than the NIFTy path on smooth-SFH fits
but is not drop-in posterior-equivalent: PSD-timescale parameters
differ by an order of magnitude on stochastic fits. Validate per problem
before swapping.
Persistent compile cache¶
JAX recompiles XLA programs on every cold start. Tengri auto-enables a
persistent on-disk cache at ~/.cache/tengri_jax_cache so notebook
restarts, slurm tasks, and benchmark runs all skip the expensive first
compile (geoVI ~75 s, MGVI ~10 s, NUTS warmup tens of seconds).
```shell
export TENGRI_JAX_CACHE_DIR=/scratch/$USER/jax_cache  # custom location
export TENGRI_DISABLE_JAX_CACHE=1                      # opt out
```
After upgrading JAX, wipe stale entries:
```python
import tengri
tengri.clear_cache()
```
The default `min_compile_time_secs=5.0` keeps small SSP/dust kernels out of
the cache. See `compilation_cache.md`
and `compilation_diagnostics.md`
for full details.
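These settings map onto JAX's own persistent-cache knobs. A sketch of setting them directly (the path is illustrative; whether tengri's wrapper uses exactly these options is an assumption), in case you need behaviour the wrapper doesn't expose:

```python
import jax

# Same effect as TENGRI_JAX_CACHE_DIR, but set per-process:
jax.config.update("jax_compilation_cache_dir", "/scratch/me/jax_cache")
# Only cache programs whose compile takes >= 5 s (tengri's default cutoff):
jax.config.update("jax_persistent_cache_min_compile_time_secs", 5.0)
```

Both options must be set before the first jitted call compiles, or that program skips the cache.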
Health check and dispatcher¶
A one-command quick read of your install:

```shell
python -m tengri.bench
```

This prints the JAX backend, default device, persistent compile-cache size, and a 1-galaxy + 100-galaxy timing on SDSS ugriz. It takes ~30 s on CPU after the cache is warm.
Every comprehensive benchmark script under scripts/ is also reachable
through one entry point:
```shell
python -m tengri.bench list                # show all
python -m tengri.bench help forward_model  # what does it measure?
python -m tengri.bench forward_model       # run it
```
Available benchmarks (`bench list` prints the exact script names), by what each one measures:

- Forward photometry: exact / compositional / hybrid across all emitters
- Per-component (stellar, dust, nebular, AGN, …) wall-clock timing
- Population-scale JIT compile time vs N galaxies
- Compile time on the production forward-model path
- MAP / Laplace / NUTS / VI / NSS at D = 7, 12, 20
- geoVI: pure-JAX
- VI scaling on stochastic-SFH problems with D >> 100
- Hierarchical PopulationFitter: per-iteration cost vs N galaxies
- MAP optimizers head-to-head
- Cue (Li+2025) nebular emulator timing in isolation
- Per-call loss / negative-log-posterior timing
- End-to-end timing for joint photometry + spectral indices
- Analytic precompute lookup vs full-spectrum integration
- Quadrature precompute: accuracy vs grid resolution
- Metallicity-table interpolation kernel timing
Reproducing the headline numbers¶
```shell
JAX_PLATFORMS=cpu python -m tengri.bench forward_model
JAX_PLATFORMS=cpu python -m tengri.bench inference_engines
```
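For scripted re-runs (CI jobs, parameter sweeps), the same invocation can be built programmatically. A small sketch (the `bench_cmd` helper is hypothetical; the module path is the real entry point):

```python
import os
import sys

def bench_cmd(script, platform="cpu"):
    """Command + environment for one pinned-platform benchmark run."""
    env = dict(os.environ, JAX_PLATFORMS=platform)
    return [sys.executable, "-m", "tengri.bench", script], env

cmd, env = bench_cmd("forward_model")
# Launch with e.g. subprocess.run(cmd, env=env, check=True)
```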
Each script writes its dated report to bench/reports/ (or to
stdout, depending on the script). The reports there are the source of
truth for every number quoted on this page.
`bench/RERUN.md` tracks which scripts are due for a re-run.
Hardware notes¶
All numbers above are single CPU core on Apple M-series hardware. Tengri runs on JAX, so the same code executes on GPU/TPU without modification — but those platforms have not been benchmarked. See Getting Started → GPU for setup.
JAX Metal (Apple GPU) is experimental and causes test failures; CPU is the supported reference platform for benchmarks. Set `JAX_PLATFORMS=cpu` to be explicit.

Memory: smooth D = 7 fits run in ~100 MB; stochastic D = 137 fits in ~1.5 GB. NUTS warmup with `dense_mass=True` peaks at 3–6× steady state on small models and can hit 20+ GB at D ≥ 8 with `dense_basis` SFHs; multi-fit notebooks need `dense_mass=False`. See Memory expectations for the full table and the two recurring OOM patterns.
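The dense-mass advice can be encoded as a guard in fitting scripts. A hypothetical helper (not tengri API) expressing the rule of thumb from the memory notes above:

```python
def use_dense_mass(n_params, n_fits=1):
    """Dense mass matrices only pay off on small single fits.

    Per the memory notes: warmup with dense_mass=True can hit 20+ GB
    at D >= 8, and multi-fit sessions should always use diagonal mass.
    """
    return n_params < 8 and n_fits == 1

use_dense_mass(7)             # small smooth-SFH fit: dense mass is fine
use_dense_mass(12)            # dense mass risks OOM
use_dense_mass(7, n_fits=10)  # multi-fit session: use diagonal mass
```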
When numbers look wrong¶
If `python -m tengri.bench` shows a much slower 1-galaxy timing than
the table above:

1. Confirm `x64: True` (some downstream behaviour assumes 64-bit).
2. Confirm `default device: cpu` — Metal sometimes silently picks itself up and slows things down. Force CPU with `JAX_PLATFORMS=cpu`.
3. Check the cache size — if it's in the GB range with hundreds of files, `tengri.clear_cache()` after a JAX upgrade is sometimes the fix.
4. The default `bench` SSP grid is whatever first matches `data/ssp_*.h5`; a multi-Z, full-α/Fe grid is meaningfully slower than the `prsc_miles` grid used in `notebooks/00_quickstart`. The relative numbers (vmap speedup, exact-vs-hybrid ratio) are what matter.
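This checklist can be automated in a session. A pure-Python sketch (the helper name is hypothetical; in practice feed it `jax.devices()[0].platform`, `jax.config.jax_enable_x64`, and the cache-directory file count):

```python
def bench_env_warnings(platform, x64_enabled, n_cache_files):
    """Return human-readable warnings mirroring the checklist above."""
    warnings = []
    if not x64_enabled:
        warnings.append("x64 disabled: enable 64-bit before benchmarking")
    if platform != "cpu":
        warnings.append(
            f"default device is {platform!r}: force CPU with JAX_PLATFORMS=cpu")
    if n_cache_files > 500:
        warnings.append("compile cache is huge: consider tengri.clear_cache()")
    return warnings

bench_env_warnings("metal", x64_enabled=True, n_cache_files=12)
# -> ["default device is 'metal': force CPU with JAX_PLATFORMS=cpu"]
```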