Tengri

A JAX framework for differentiable galaxy SED fitting. One modular forward model spans stars, dust, nebular emission, AGN, and IGM — from X-ray to radio. Every inference method (MAP, Laplace, Pathfinder, NUTS, Ray Tracing, Bayesian evidence, hierarchical population) runs on the same model, with gradients available everywhere.


The name Tengri comes from the all-encompassing God of Heaven in the traditional Turkic, Mongolic, and other Central Asian nomadic religions, which makes it a fitting name for a code that models the light of galaxies across cosmic time. The name was chosen with respect for the cultural and spiritual traditions it comes from; no religious claim or appropriation is intended.


Status: v0.1.0, active development. Core pipeline functional with 2000+ tests. Paper I in preparation.

Why tengri

The forward model is pure JAX end-to-end: the same code that produces an SED also gives you its gradient. That makes every inference method — MAP, Laplace, Pathfinder, NUTS, Ray Tracing, geoVI, nested sampling — a thin wrapper over the same model. There are no separate fast/slow paths, no Fortran/C extensions to keep in sync, no manual derivatives to maintain.
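The sketch below makes the first point concrete. It uses a toy two-parameter power-law "SED", which is hypothetical and not tengri's forward model. jax.value_and_grad turns a chi-squared function into one that also returns its gradient, and jax.jit compiles both together, so no derivative is written by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy forward model, not tengri's: a power-law continuum
# evaluated on a fixed wavelength grid (in Angstrom).
WAVE_AA = jnp.linspace(3000.0, 10000.0, 64)

def toy_sed(params):
    log_amp, slope = params
    return 10.0**log_amp * (WAVE_AA / 5500.0) ** slope

def chi2(params, flux_obs, sigma):
    # Gaussian chi-squared between the model SED and the observed fluxes
    resid = (toy_sed(params) - flux_obs) / sigma
    return jnp.sum(resid**2)

# One jit-compiled call returns the value and its gradient with respect
# to params; the derivative comes from autodiff, not hand-written code.
chi2_and_grad = jax.jit(jax.value_and_grad(chi2))

true_params = jnp.array([0.0, -1.5])
flux_obs = toy_sed(true_params)
sigma = 0.05 * flux_obs
value, grad = chi2_and_grad(jnp.array([0.1, -1.0]), flux_obs, sigma)
```

Any optimizer or sampler that needs gradients can call chi2_and_grad directly. This is why, in a pure-JAX model, each inference method reduces to a thin wrapper.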

The physics is modular: stars (DSPS SSPs — BC03, BPASS, FSPS, ProGeny), SFH (parametric, non-parametric, and IFT correlated-field), dust attenuation and emission, nebular (BakedIn / CloudyGrid / Cue), a unified AGN block (disc + torus + BLR/NLR), IGM absorption, radio, X-ray. Each component is a pure function you can swap, vmap, or differentiate without touching the rest of the pipeline.
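Here is a minimal sketch of what "swap, vmap, or differentiate" means in practice. It uses two hypothetical dust laws, a power-law screen and a grey screen, neither of which is tengri's implementation. The downstream function depends only on the component's signature, so either law can be plugged in, batched with jax.vmap, or differentiated with jax.grad.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy components, not tengri's own dust models.
def powerlaw_attenuation(wave_aa, tau_v, slope=-0.7):
    # Transmission exp(-tau(lambda)) with tau(lambda) = tau_V * (lambda / 5500 A)^slope
    return jnp.exp(-tau_v * (wave_aa / 5500.0) ** slope)

def grey_attenuation(wave_aa, tau_v):
    # Wavelength-independent ("grey") screen with the same signature
    return jnp.exp(-tau_v) * jnp.ones_like(wave_aa)

def attenuated_flux(tau_v, intrinsic, wave_aa, dust_fn=powerlaw_attenuation):
    # Only the component's signature matters here, so swapping dust_fn
    # changes the physics without touching this function.
    return jnp.sum(intrinsic * dust_fn(wave_aa, tau_v))

wave = jnp.linspace(1500.0, 9000.0, 100)
intrinsic = jnp.ones_like(wave)

# vmap: evaluate a batch of tau_V values in one call
taus = jnp.linspace(0.0, 2.0, 5)
batch = jax.vmap(attenuated_flux, in_axes=(0, None, None))(taus, intrinsic, wave)

# grad: sensitivity of the total flux to tau_V
dflux_dtau = jax.grad(attenuated_flux)(0.5, intrinsic, wave)

# swap: the same call with the grey screen
grey = attenuated_flux(0.5, intrinsic, wave, dust_fn=grey_attenuation)
```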

On a smooth 7-parameter model, the forward call takes ~140 μs on an Apple M-series CPU and the gradient ~56 μs. The same source runs unchanged on GPU and TPU. SSPs come from any HDF5 file matching the DSPS schema; pre-formatted grids are mirrored here.

Paper I covers the framework and parametric mock recovery. Paper II introduces stochastic, IFT correlated-field SFHs with PSD-governed burstiness priors, fit through geoVI.
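The sketch below is a rough illustration of a PSD-governed prior. It draws log SFR(t) as a Gaussian random field whose power spectrum is a power law in frequency. The slope alpha controls burstiness. A steep spectrum puts most of the power on long timescales and gives smooth histories, while a shallow one puts more power on short timescales and gives bursty histories. This is a plain FFT construction with hypothetical parameter names. It is not NIFTy.re's correlated-field model or the Paper II prior.

```python
import jax
import jax.numpy as jnp

def draw_log_sfr(key, n_bins=128, dt_gyr=0.1, alpha=2.0, rms=0.3, mean_log_sfr=0.0):
    """Draw a log-SFR history from a stationary Gaussian field with a
    power-law power spectrum, P(f) proportional to f**(-alpha).

    Illustrative FFT construction only (hypothetical helper).
    """
    freqs = jnp.fft.rfftfreq(n_bins, d=dt_gyr)               # cycles per Gyr
    safe_f = jnp.where(freqs > 0, freqs, 1.0)                # avoid 0**(-alpha)
    psd = jnp.where(freqs > 0, safe_f ** (-alpha), 0.0)      # zero-frequency (mean) mode removed
    k_re, k_im = jax.random.split(key)
    # Complex white noise, colored by the square root of the PSD
    noise = jax.random.normal(k_re, freqs.shape) + 1j * jax.random.normal(k_im, freqs.shape)
    field = jnp.fft.irfft(noise * jnp.sqrt(psd), n=n_bins)
    # Simplification: rescale each draw to the requested rms fluctuation
    field = rms * field / jnp.std(field)
    return mean_log_sfr + field

key = jax.random.PRNGKey(1)
smooth = draw_log_sfr(key, alpha=3.0)   # steep PSD: slowly varying SFR
bursty = draw_log_sfr(key, alpha=0.5)   # shallow PSD: short-timescale bursts
```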

Installation

pip install -e .              # core (JAX, DSPS, NIFTy)
pip install -e ".[all]"       # + BlackJAX (NUTS) + optax (MAP)
pip install -e ".[dev]"       # + pytest, ruff, jupytext

Requirements: Python ≥ 3.10, JAX ≥ 0.4.20, DSPS ≥ 0.3, NIFTy.re ≥ 8.5.

Quick start

import jax
from tengri import (
    SEDModel, Parameters, Fitter,
    Uniform, Gaussian,
    Observation, Photometry, load_ssp_data,
)

ssp = load_ssp_data("data/ssp_fsps_v3.2.h5")
obs = Observation(photometry=Photometry.from_names(
    ["sdss_u", "sdss_g", "sdss_r", "sdss_i", "sdss_z"]
))

spec = Parameters(
    sfh_tsnorm_log_peak_sfr=Uniform(-1, 2),
    sfh_tsnorm_peak_lbt_gyr=Uniform(1, 12),
    sfh_tsnorm_width_gyr=Uniform(0.5, 5),
    met_logzsol=Gaussian(-0.3, 0.2),
    dust_tau_bc=Uniform(0, 4),
    redshift=0.1,
)

model = SEDModel(spec, ssp, observation=obs)

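key = jax.random.PRNGKey(0)
key_params, key_noise = jax.random.split(key)   # independent streams for parameters and noise
mock = model.mock(spec.sample(key_params), key=key_noise)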

fitter = Fitter(model, mock["flux_obs"], mock["noise"])
result = fitter.run("mcmc_nuts")
print(result.summary_table())

The full walkthrough is in notebooks/00_quickstart.py. It covers how the mock is constructed, what the priors do, and how to read the corner plot. NUTS requires the optional BlackJAX dependency, which is installed with pip install -e ".[all]".

Inference methods

Method              Call                                          Best for
------------------  --------------------------------------------  ---------------------------------------------------
MAP                 fitter.run("map")                             Point estimates, initialization
Laplace             fitter.run("laplace")                         Gaussian posterior from the Hessian at the MAP
Pathfinder          fitter.run("pathfinder")                      Fast approximate posterior; good NUTS warm-start
NUTS                fitter.run("mcmc_nuts")                       Gold-standard posterior (D ≲ 30)
Ray Tracing         fitter.run("mcmc_raytrace")                   Exact MCMC, noise-robust, scales past D = 30
Evidence (NSS)      fitter.run("evidence")                        Bayesian evidence for model comparison
Population          PopulationFitter(...)                         Shared hyperparameters across galaxy samples
geoVI / vi_native   fitter.run("vi") or fitter.run("vi_native")   Paper II preview: high-D stochastic SFHs (D ≈ 137+)
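The Laplace row can be illustrated end to end on a hypothetical toy 2-D Gaussian posterior, chosen because the approximation is exact there. The sketch finds the MAP with jax.scipy.optimize.minimize, takes the Hessian of the negative log-posterior with jax.hessian, and inverts it to get the covariance. It shows the general technique, not tengri's Fitter internals.

```python
import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize

# Hypothetical toy posterior: a correlated 2-D Gaussian, for which the
# Laplace approximation is exact and easy to check.
MU = jnp.array([1.0, -2.0])
COV = jnp.array([[1.0, 0.6], [0.6, 2.0]])
PREC = jnp.linalg.inv(COV)

def neg_log_post(theta):
    d = theta - MU
    return 0.5 * d @ PREC @ d

# 1. MAP: minimize the negative log-posterior (BFGS from jax.scipy)
res = minimize(neg_log_post, jnp.zeros(2), method="BFGS")
theta_map = res.x

# 2. Laplace: the Hessian of -log p at the MAP is the posterior precision
hess = jax.hessian(neg_log_post)(theta_map)
cov_laplace = jnp.linalg.inv(hess)

# 3. Approximate posterior samples from N(theta_map, cov_laplace)
samples = jax.random.multivariate_normal(
    jax.random.PRNGKey(0), theta_map, cov_laplace, shape=(2000,)
)
```

On a real, non-Gaussian posterior the same three steps give only a local Gaussian approximation. This is why MAP and Laplace are listed for point estimates and initialization, while NUTS and Ray Tracing are listed for the full posterior.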

Method choice is introduced in notebooks/05_fitting_photometry.py; a deeper walkthrough will land in a future spine notebook.
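The sketch below illustrates what the Population row means by shared hyperparameters, using a hypothetical toy two-level model rather than tengri's PopulationFitter API. Each galaxy's parameter is drawn from a common Gaussian N(mu, tau). jax.vmap evaluates every galaxy's prior and likelihood terms in one call, so the hyperparameters receive gradient contributions from the whole sample. For brevity the hyperprior is flat.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

def log_joint(hyper, thetas, y_obs, sigma_obs):
    """Hypothetical two-level model (flat hyperprior):
    theta_i ~ N(mu, tau) for each galaxy i, and y_i ~ N(theta_i, sigma_i)."""
    mu, log_tau = hyper
    tau = jnp.exp(log_tau)

    def per_galaxy(theta, y, s):
        # population-prior term + this galaxy's own likelihood term
        return norm.logpdf(theta, mu, tau) + norm.logpdf(y, theta, s)

    # vmap maps per_galaxy over the sample; (mu, tau) are closed over, i.e. shared
    return jnp.sum(jax.vmap(per_galaxy)(thetas, y_obs, sigma_obs))

# Toy data for three galaxies, one scalar parameter each
y_obs = jnp.array([0.1, 0.5, -0.2])
sigma_obs = jnp.array([0.1, 0.2, 0.1])
hyper = jnp.array([0.0, jnp.log(0.5)])   # (mu, log tau)
thetas = y_obs

# Gradients w.r.t. the shared hyperparameters and all per-galaxy parameters at once
grad_hyper, grad_thetas = jax.grad(log_joint, argnums=(0, 1))(hyper, thetas, y_obs, sigma_obs)
```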

Tutorial spine

Tutorials live as Jupytext .py files in notebooks/ and are synced to docs/spine/*.ipynb via python scripts/sync_spine_notebooks_for_docs.py.

The spine is written for astronomers. It uses physics framing and copy-pasteable code cells, and it teaches progressively across notebooks. Start with 00 and 01, then branch out based on your use case.

Performance (Apple M-series CPU)

Operation        Smooth (D=7)    Stochastic (D=137)
---------------  --------------  ------------------
Forward model    140 μs          356 μs
Gradient         56 μs           63 μs
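Microsecond timings depend heavily on how they are measured. The sketch below shows a common JAX pattern, using a hypothetical 7-parameter toy model rather than tengri's SED model. It calls the function once to trigger JIT compilation, then times repeated calls. Each call is wrapped in jax.block_until_ready because JAX dispatches work asynchronously. The numbers it produces are not comparable to the table above.

```python
import statistics
import time

import jax
import jax.numpy as jnp

def time_call(fn, *args, n_repeat=200):
    """Median wall-clock time per call, in microseconds, excluding compilation."""
    jax.block_until_ready(fn(*args))        # warm-up call triggers JIT compilation
    times = []
    for _ in range(n_repeat):
        t0 = time.perf_counter()
        jax.block_until_ready(fn(*args))    # wait for the asynchronously dispatched result
        times.append(time.perf_counter() - t0)
    return 1e6 * statistics.median(times)

# Hypothetical smooth 7-parameter toy model: 7 fixed Gaussian bumps with free amplitudes
WAVE = jnp.linspace(1000.0, 20000.0, 4000)
CENTERS = jnp.linspace(2000.0, 18000.0, 7)

def toy_forward(theta):
    profiles = jnp.exp(-0.5 * ((WAVE - CENTERS[:, None]) / 1500.0) ** 2)
    return jnp.sum(theta[:, None] * profiles, axis=0)

forward = jax.jit(toy_forward)
gradient = jax.jit(jax.grad(lambda th: jnp.sum(toy_forward(th) ** 2)))

theta0 = jnp.ones(7)
t_fwd = time_call(forward, theta0)
t_grad = time_call(gradient, theta0)
```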

Dependencies

Package      Role                  Required
-----------  --------------------  --------
JAX          Autodiff, JIT, GPU    Yes
DSPS         Differentiable SPS    Yes
NIFTy.re     geoVI / MGVI          Yes
NumPy        Array utilities       Yes
Matplotlib   Plotting              Yes
h5py         SSP I/O               Yes
BlackJAX     NUTS / HMC            Optional
optax        MAP optimization      Optional
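The sketch below checks which of these packages can be imported in the current environment, using only importlib from the standard library. The import names are assumptions, in particular "nifty8" as the package that provides NIFTy.re. The optional packages are installed with pip install -e ".[all]".

```python
import importlib.util

# Import names are assumptions based on the table above; "nifty8" as the
# package shipping NIFTy.re (imported as nifty8.re) may differ by version.
REQUIRED = ["jax", "dsps", "nifty8", "numpy", "matplotlib", "h5py"]
OPTIONAL = {"blackjax": "NUTS / HMC", "optax": "MAP optimization"}

def missing(modules):
    """Return the top-level module names that Python cannot locate."""
    return [name for name in modules if importlib.util.find_spec(name) is None]

if __name__ == "__main__":
    print("missing required:", missing(REQUIRED) or "none")
    for name in missing(OPTIONAL):
        print(f"optional {name} not installed: {OPTIONAL[name]} unavailable")
```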

References

License: BSD-3-Clause

Physics deep dives