Fitting Photometry
Where `00_quickstart <00_quickstart.py>`__ was a fast demo, this is the real workflow: realistic mock data, NUTS, convergence checks, credible intervals on derived properties, and posterior-predictive validation.
Physics: double-power-law SFH, Calzetti two-component dust, free redshift, nebular continuum on. Dust emission (the Dale et al. infrared template) is left out here because the bandset stops at WISE W2. The point isn't that any of these are exotic; it's that they are fit in the usual sense, with proper diagnostics.
UV–MIR photometry constrains combinations of age, dust, and metallicity rather than each on its own. The posterior is degenerate; tight priors or extra data (spectroscopy in `06_fitting_spectroscopy <06_fitting_spectroscopy.py>`__) are how you break it.
Setup
[1]:
import contextlib
import os
import sys
import time
import warnings
os.environ.setdefault("TENGRI_NO_BACKGROUND_COMPILE", "1")
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.45")
try:
    _nb_dir = os.path.dirname(os.path.abspath(__file__))
    _repo_root = os.path.abspath(os.path.join(_nb_dir, ".."))
except NameError:
    _nb_dir = os.getcwd()
    _repo_root = os.path.abspath(os.path.join(_nb_dir, ".."))
_src = os.path.join(_repo_root, "src")
if os.path.isdir(os.path.join(_src, "tengri")):
    sys.path.insert(0, _src)
sys.path.insert(0, _repo_root)
sys.path.insert(0, _nb_dir)
import jax
import jax.numpy as jnp
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np
if "ipykernel" not in sys.modules:
    matplotlib.use("Agg")
jax.config.update("jax_enable_x64", True)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message=".*BakedInBackend.*", category=UserWarning)
import importlib.util
_repo_data_root = None
_spec_tengri = importlib.util.find_spec("tengri")
if _spec_tengri is not None and _spec_tengri.origin:
    _walk = os.path.dirname(os.path.abspath(_spec_tengri.origin))
    for _step in range(12):
        _candidate = os.path.join(_walk, "notebooks", "_plot_style.py")
        if os.path.isfile(_candidate):
            sys.path.insert(0, os.path.dirname(_candidate))
            _repo_data_root = os.path.dirname(os.path.dirname(os.path.abspath(_candidate)))
            break
        _parent_walk = os.path.dirname(_walk)
        if _parent_walk == _walk:
            break
        _walk = _parent_walk
if _repo_data_root is None:
    _np_here = os.path.abspath(os.getcwd())
    while True:
        if os.path.isfile(os.path.join(_np_here, "_plot_style.py")):
            sys.path.insert(0, _np_here)
            _repo_data_root = os.path.dirname(_np_here)
            break
        _ppt = os.path.join(_np_here, "notebooks", "_plot_style.py")
        if os.path.isfile(_ppt):
            _nbsd = os.path.dirname(_ppt)
            sys.path.insert(0, _nbsd)
            _repo_data_root = os.path.dirname(_nbsd)
            break
        _parent_here = os.path.dirname(_np_here)
        if _parent_here == _np_here:
            break
        _np_here = _parent_here
if _repo_data_root is not None and os.path.isdir(os.path.join(_repo_data_root, "data")):
    os.chdir(_repo_data_root)
elif os.path.isdir(os.path.join(_repo_root, "data")):
    os.chdir(_repo_root)
elif os.path.isdir("data"):
    pass
elif os.path.isdir(os.path.join("..", "data")):
    os.chdir("..")
FIGDIR = os.path.join("notebooks", "figures")
os.makedirs(FIGDIR, exist_ok=True)
from _plot_style import COLORS, setup_style
setup_style()
import tengri as tg
from tengri import (
    Fitter,
    Fixed,
    LogUniform,
    Observation,
    Parameters,
    Photometry,
    SEDModel,
    Uniform,
    load_ssp_data,
)
tg.print_logo()
print(f"tengri {tg.__version__}")
tengri 0.1.0
Load SSP and assemble bandset
[2]:
_ssp_name = "ssp_mist_c3k_a_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5"
_ssp_path = os.path.join("data", _ssp_name)
if not os.path.exists(_ssp_path):
    _ssp_name = "ssp_prsc_miles_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5"
    _ssp_path = os.path.join("data", _ssp_name)
ssp_data = load_ssp_data(_ssp_path)
print(f"SSP: {ssp_data.ssp_flux.shape[0]} Z × {ssp_data.ssp_flux.shape[1]} ages")
# UV-to-NIR bandset. We deliberately stop at W2 (4.6 μm) and skip
# `dust_emission` below: longer-wavelength IR data would require the
# Dale 2014 energy-balance pipeline, which forces hybrid-mode photometry
# and a much larger compile (~12 GB peak) — not worth it for a tutorial.
# See notebook 11 for the full panchromatic energy-balance treatment.
filter_names = [
    "galex_fuv", "galex_nuv",
    "sdss_u", "sdss_g", "sdss_r", "sdss_i", "sdss_z",
    "2mass_j", "2mass_h", "2mass_ks",
    "wise_w1", "wise_w2",
]
phot_obs = Photometry.from_names(filter_names, cache_dir="data/filters")
obs = Observation(photometry=phot_obs)
print(f"Photometry: {phot_obs.n_filters} bands (GALEX/SDSS/2MASS/WISE-W1W2)")
SSP: 12 Z × 107 ages
Photometry: 12 bands (GALEX/SDSS/2MASS/WISE-W1W2)
Model definition (8 free parameters)
[3]:
spec = Parameters(
    # Double power-law SFH (4 params): simpler than dense_basis,
    # smaller compile graph.
    mean_sfh_type="dpl",
    # Reparameterised: positive-definite physical quantities use LogUniform
    # so the unconstrained xi-space is roughly Gaussian. This flattens the
    # geometry near the lower boundary (tau→0, alpha→1) where Uniform
    # priors create steep gradients and trigger NUTS divergences.
    sfh_dpl_log_peak_sfr=Uniform(-1.0, 2.5),  # already in log space
    sfh_dpl_tau_gyr=LogUniform(0.5, 12.0),
    sfh_dpl_alpha=LogUniform(1.0, 8.0),
    sfh_dpl_beta=LogUniform(1.0, 8.0),
    met_logzsol=Uniform(-1.5, 0.3),  # already a log quantity
    dust_tau_bc=LogUniform(0.01, 2.0),
    dust_tau_diff=LogUniform(0.01, 1.5),
    dust_slope=Fixed(-0.7),
    # No dust_emission: UV-NIR alone is well-constrained without IR.
    # Free redshift: `SEDModel(approx={"ztable": ...})` auto-precomputes
    # a redshift table interpolated by the `hybrid_ztable` kernel, so
    # free z costs no more compile time than fixed z. Defaults: z_min /
    # z_max pulled from the prior with 1% padding, n_z=100. Override:
    #   approx={"ztable": {"z_min": 0.01, "z_max": 3.0, "n_z": 200}}
    # or disable: approx={"ztable": False}.
    redshift=Uniform(0.01, 0.5),
)
print(f"\nModel: {spec.n_free} free parameters")
print(f" {', '.join(spec.free_params[:5])}...")
t0 = time.perf_counter()
model = SEDModel(spec, ssp_data, observation=obs)
t_model = time.perf_counter() - t0
print(f" ⏱ SEDModel construction {t_model:.2f} s (auto-ztable for free z)")
print(f" Recommended method: {model.recommend_method()}")
# Time cold/warm forward passes — the canonical "is JIT working" signal
t0 = time.perf_counter()
_ = model.predict_photometry({**spec.sample(jax.random.PRNGKey(0))})
t_first = time.perf_counter() - t0
t0 = time.perf_counter()
_ = model.predict_photometry({**spec.sample(jax.random.PRNGKey(1))})
t_warm = time.perf_counter() - t0
print(f" ⏱ predict_photometry cold={t_first*1e3:.1f} ms warm={t_warm*1e3:.1f} ms")
Model: 8 free parameters
dust_tau_bc, dust_tau_diff, met_logzsol, redshift, sfh_dpl_alpha...
⏱ SEDModel construction 7.46 s (auto-ztable for free z)
Recommended method: laplace
⏱ predict_photometry cold=1443.8 ms warm=7.2 ms
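The LogUniform priors in the model definition amount to a flat prior on the logarithm of the parameter. The NumPy sketch below illustrates that idea on its own: `loguniform_from_unit` and `loguniform_logpdf` are hypothetical helper names, not tengri's API, and tengri's internal unconstrained transform may well differ.

```python
import numpy as np

def loguniform_from_unit(u, lo, hi):
    """Map u in (0, 1) to a draw that is uniform in log(x) on [lo, hi]."""
    log_lo, log_hi = np.log(lo), np.log(hi)
    return np.exp(log_lo + np.asarray(u, dtype=float) * (log_hi - log_lo))

def loguniform_logpdf(x, lo, hi):
    """Log-density p(x) = 1 / (x * ln(hi / lo)) inside [lo, hi], -inf outside."""
    x = np.asarray(x, dtype=float)
    inside = (x >= lo) & (x <= hi)
    safe_x = np.where(inside, x, 1.0)  # avoid log of invalid values outside the support
    return np.where(inside, -np.log(safe_x) - np.log(np.log(hi / lo)), -np.inf)

# Illustrative draws using the same bounds as sfh_dpl_tau_gyr above.
rng = np.random.default_rng(0)
tau_draws = loguniform_from_unit(rng.uniform(size=10_000), 0.5, 12.0)
print(f"median tau = {np.median(tau_draws):.2f} Gyr")
```

Half of the prior mass sits below the geometric mean of the bounds, which is why small values of tau, alpha, or the dust optical depths are not squeezed against the lower edge the way they would be under a flat prior in linear space.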
Generate mock photometry (SNR=15)
[4]:
key = jax.random.PRNGKey(123)
truth = spec.sample(key)
# Override with realistic values: z = 0.08, log10(M*/Msun) ~ 10.5, rising SFH
truth = {**truth}
truth["redshift"] = jnp.array(0.08)
truth["sfh_dpl_log_peak_sfr"] = jnp.array(np.log10(15.0))
truth["sfh_dpl_tau_gyr"] = jnp.array(3.0)
truth["sfh_dpl_alpha"] = jnp.array(3.5)
truth["sfh_dpl_beta"] = jnp.array(2.0)
truth["met_logzsol"] = jnp.array(-0.05)
truth["dust_tau_bc"] = jnp.array(0.4)
truth["dust_tau_diff"] = jnp.array(0.25)
t0 = time.perf_counter()
mock_data = model.mock(truth, snr=15.0, key=key)
print(f" ⏱ mock generation {time.perf_counter()-t0:.2f} s")
print(f"\nTrue parameters (z={float(truth['redshift']):.3f}):")
for name in spec.free_params[:6]:
    print(f" {name:30s} = {float(truth[name]):.4f}")
print(f"\nMock: {len(mock_data.flux_obs)} bands, SNR=15")
⏱ mock generation 1.01 s
True parameters (z=0.080):
dust_tau_bc = 0.4000
dust_tau_diff = 0.2500
met_logzsol = -0.0500
redshift = 0.0800
sfh_dpl_alpha = 3.5000
sfh_dpl_beta = 2.0000
Mock: 12 bands, SNR=15
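For readers who want to see what a fixed-SNR mock involves, here is a minimal NumPy sketch. It assumes independent Gaussian noise per band with sigma = flux / SNR; whether `model.mock` does exactly this is not confirmed by this notebook, and `add_fixed_snr_noise` is a hypothetical helper name.

```python
import numpy as np

def add_fixed_snr_noise(flux_true, snr, rng):
    """Perturb noiseless fluxes with Gaussian noise at a constant per-band SNR.

    Assumption: sigma_b = |flux_b| / snr, independent between bands.
    """
    flux_true = np.asarray(flux_true, dtype=float)
    sigma = np.abs(flux_true) / snr
    flux_obs = flux_true + sigma * rng.standard_normal(flux_true.shape)
    return flux_obs, sigma

# Toy fluxes in arbitrary units, only to exercise the function.
rng = np.random.default_rng(123)
flux_true = np.array([1.0, 2.5, 4.0, 3.2])
flux_obs, sigma = add_fixed_snr_noise(flux_true, snr=15.0, rng=rng)
print(np.round(flux_obs / sigma, 1))  # realised per-band SNR, scattered around 15
```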
Inference with NUTS
[5]:
print("FITTING: NUTS")
fitter = Fitter(model, mock_data.flux_obs, mock_data.noise)
# NUTS over fixed-L HMC: this 8-D photometry posterior has wildly
# different scales per parameter (z in [0.01, 0.5] vs sfh_dpl_alpha in
# [1, 8]) and curved age-dust degeneracies. Fixed-L HMC needed
# unrealistically long warmup to mix; NUTS adapts both step size and
# tree depth, so a single 500-warmup chain is expected to converge
# (target R̂ < 1.05; check the diagnostics printed in the summary cell).
# ``dense_mass_matrix=False`` keeps the warmup compile graph bounded; see
# docs/dev/notebook_orchestration_oom.md for why dense-mass NUTS
# triggers macOS jetsam at >20 GB peak RSS.
t0 = time.perf_counter()
result = fitter.run(
    "mcmc_nuts",
    n_warmup=500,
    n_samples=600,
    target_accept_rate=0.85,
    dense_mass_matrix=False,
    verbose=False,
    key=jax.random.PRNGKey(789),
)
t_fit = time.perf_counter() - t0
print(f"NUTS: {t_fit:.1f}s (warmup=500 + samples=600, single chain)")
print(f" Divergences: {result.diagnostics.get('n_divergent', 'n/a')}")
print(f" Step size: {result.diagnostics.get('step_size', float('nan')):.4f}")
print(f" Samples: {len(next(iter(result.samples.values())))}")
samples_source = result.samples
======================================================================
FITTING: MAP optimization
======================================================================
✓ NUTS: 1523.1s (warmup=500 + samples=600, single chain)
Divergences: 455
Step size: 0.0321
Samples: 600
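With a single chain, the usual way to obtain a convergence statistic is to split the chain in half and compare the halves. The sketch below implements the classic split-R̂ (not the rank-normalised variant) in plain NumPy. It is independent of tengri, and `split_rhat` is a hypothetical helper name; `result.rhat()` may use a different estimator.

```python
import numpy as np

def split_rhat(chains):
    """Classic split-R-hat for one parameter.

    `chains` has shape (n_samples,) or (n_chains, n_samples). Each chain is
    cut in half so that a drifting single chain shows up as disagreement
    between its first and second half.
    """
    chains = np.atleast_2d(np.asarray(chains, dtype=float))
    n = chains.shape[1] // 2
    halves = np.concatenate([chains[:, :n], chains[:, n:2 * n]], axis=0)
    within = halves.var(axis=1, ddof=1).mean()     # mean within-half variance
    between = n * halves.mean(axis=1).var(ddof=1)  # variance of half means, scaled by n
    var_plus = (n - 1) / n * within + between / n  # pooled variance estimate
    return float(np.sqrt(var_plus / within))

rng = np.random.default_rng(0)
well_mixed = rng.normal(size=1000)
drifting = np.linspace(0.0, 5.0, 1000) + rng.normal(size=1000)
rhat_good = split_rhat(well_mixed)
rhat_bad = split_rhat(drifting)
print(f"well mixed: {rhat_good:.3f}   drifting: {rhat_bad:.3f}")
```

Values close to 1 indicate that the two halves agree; the comment in the fitting cell uses R̂ < 1.05 as the target. The divergence count printed above is a separate diagnostic and should be read alongside R̂ rather than instead of it.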
Fit quality assessment
[6]:
print("FIT SUMMARY")
print("\nOptimized parameters (MAP):")
for name in spec.free_params[:5]:
    print(f" {name:30s} = {float(result.params[name]):.4f}")
# Posterior samples come straight from NUTS; no Laplace fallback needed.
samples_for_credible = result.samples
n_samps = len(next(iter(samples_for_credible.values())))
print(f"\nNUTS posterior: {n_samps} samples")
======================================================================
FIT SUMMARY
======================================================================
Optimized parameters (MAP):
dust_tau_bc = 1.0443
dust_tau_diff = 0.4854
met_logzsol = -0.5780
redshift = 0.2068
sfh_dpl_alpha = 2.7297
NUTS posterior: 600 samples
Derived properties
[7]:
# Compute derived quantities sample-by-sample to keep peak RSS bounded.
# ``result.derived`` uses ``jax.vmap(predict_sfh_quantities)(samples)``
# which compiles a fresh batched kernel on top of the resident NUTS graph
# — that combo can push past macOS jetsam's threshold and SIGKILL the
# process silently. The plain Python loop reuses the un-vmapped JIT
# cache, so the nominal cost is one ~1 s compile plus ~1 ms per sample:
# ~1.5 s for 600 samples in the ideal case (the recorded run took ~22 s),
# with bounded peak RSS.
import collections as _coll
_derived_lists = _coll.defaultdict(list)
n_samp_for_derived = len(next(iter(samples_for_credible.values())))
t0 = time.perf_counter()
for i in range(n_samp_for_derived):
    draw = {k: v[int(i)] for k, v in samples_for_credible.items()}
    sfhq = model.predict_sfh_quantities(draw)
    _derived_lists["stellar_mass"].append(float(sfhq.stellar_mass))
    _derived_lists["sfr_10myr"].append(float(sfhq.sfr_10myr))
    _derived_lists["sfr_100myr"].append(float(sfhq.sfr_100myr))
    _derived_lists["ssfr"].append(float(sfhq.ssfr))
derived = {k: np.asarray(v) for k, v in _derived_lists.items()}
print(f" ⏱ derived (loop over {n_samp_for_derived} draws) {time.perf_counter()-t0:.2f} s", flush=True)
try:
    stellar_mass = derived.get("stellar_mass")
    sfr_10myr = derived.get("sfr_10myr")
    sfr_100myr = derived.get("sfr_100myr")
    ssfr = derived.get("ssfr")
    if stellar_mass is not None and len(stellar_mass) > 1:
        # ``derived["stellar_mass"]`` is total mass formed in linear M_sun.
        # Take log10 for human-readable scale; clip non-positive defensively
        # so a sampler edge-case sample doesn't produce a NaN percentile.
        log_msun = np.log10(np.clip(np.asarray(stellar_mass), 1.0, None))
        m_lo, m_med, m_hi = np.percentile(log_msun, [16, 50, 84])
        print("\nStellar mass [log10(M☉)]:")
        print(f" {m_med:.2f} +{m_hi - m_med:.2f} -{m_med - m_lo:.2f}")
    if sfr_10myr is not None and len(sfr_10myr) > 1:
        s10_lo, s10_med, s10_hi = np.percentile(sfr_10myr, [16, 50, 84])
        print("\nSFR (10 Myr) [M☉/yr]:")
        print(f" {s10_med:.3g} +{s10_hi - s10_med:.3g} -{s10_med - s10_lo:.3g}")
    if sfr_100myr is not None and len(sfr_100myr) > 1:
        s100_lo, s100_med, s100_hi = np.percentile(sfr_100myr, [16, 50, 84])
        print("\nSFR (100 Myr) [M☉/yr]:")
        print(f" {s100_med:.3g} +{s100_hi - s100_med:.3g} -{s100_med - s100_lo:.3g}")
    if ssfr is not None and len(ssfr) > 1:
        ssfr_lo, ssfr_med, ssfr_hi = np.percentile(ssfr, [16, 50, 84])
        print("\nsSFR (100 Myr) [Gyr⁻¹]:")
        print(f" {ssfr_med:.3g} +{ssfr_hi - ssfr_med:.3g} -{ssfr_med - ssfr_lo:.3g}")
except Exception as e:
    print(f"(Derived properties unavailable: {str(e)[:60]})")
======================================================================
DERIVED PROPERTIES
======================================================================
⏱ derived (loop over 600 draws) 21.76 s
Stellar mass [log10(M☉)]:
11.60 +0.01 -0.09
SFR (10 Myr) [M☉/yr]:
4.41e-11 +1.76e-10 -7.46e-12
SFR (100 Myr) [M☉/yr]:
7.78e-07 +9.88e-07 -8.34e-08
sSFR (100 Myr) [Gyr⁻¹]:
1.95e-18 +3.49e-18 -1.91e-19
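The four blocks in the cell above repeat one pattern: take the 16th, 50th, and 84th percentiles and report the median with asymmetric offsets. A small helper makes that explicit. This is a sketch in plain NumPy; `credible_interval` is a hypothetical name and not part of tengri.

```python
import numpy as np

def credible_interval(samples, lo=16.0, hi=84.0):
    """Return (median, plus, minus) for the central [lo, hi] percentile interval.

    Non-finite samples are dropped first so one bad draw cannot turn the
    percentiles into NaN.
    """
    samples = np.asarray(samples, dtype=float)
    samples = samples[np.isfinite(samples)]
    q_lo, q_med, q_hi = np.percentile(samples, [lo, 50.0, hi])
    return q_med, q_hi - q_med, q_med - q_lo

# Toy input: the integers 0..100, so the percentiles are easy to check by eye.
med, plus, minus = credible_interval(np.arange(101))
print(f"{med:.2f} +{plus:.2f} -{minus:.2f}")
```

For quantities that span orders of magnitude, such as SFR or sSFR, applying the helper to `np.log10` of the samples usually gives a more readable summary than linear offsets.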
Posterior-predictive SED fit + residuals
[8]:
n_pred = 200
n_avail = len(next(iter(samples_source.values())))
# Random subsample over the chain
sub_key = jax.random.PRNGKey(11)
idxs = jax.random.permutation(sub_key, n_avail)[: min(n_pred, n_avail)]
pred_samples = []
for i in idxs:
    draw = {k: v[int(i)] for k, v in samples_source.items()}
    with contextlib.suppress(Exception):
        pred_samples.append(np.array(model.predict_photometry(draw)))
pred_array = np.array(pred_samples)
pred_med = np.median(pred_array, axis=0)
pred_lo = np.percentile(pred_array, 16, axis=0)
pred_hi = np.percentile(pred_array, 84, axis=0)
wave_eff = np.array([
    tg.filters.compute_effective_wavelength(np.asarray(fc.wave), np.asarray(fc.trans))
    for fc in phot_obs.filters
])
wave_um = wave_eff / 10000.0
flux_ujy = np.array(tg.units.fnu_to_ujy(np.array(mock_data.flux_obs)))
noise_ujy = np.array(tg.units.fnu_to_ujy(np.array(mock_data.noise)))
pred_med_ujy = np.array(tg.units.fnu_to_ujy(pred_med))
pred_lo_ujy = np.array(tg.units.fnu_to_ujy(pred_lo))
pred_hi_ujy = np.array(tg.units.fnu_to_ujy(pred_hi))
xlo, xhi = 0.1, 30.0
mask = (wave_um >= xlo) & (wave_um <= xhi)
valid = mask & (flux_ujy > 0)
fig = plt.figure(figsize=(13, 8))
gs = gridspec.GridSpec(2, 1, height_ratios=[2.5, 1], hspace=0.05)
ax_sed = fig.add_subplot(gs[0])
ax_res = fig.add_subplot(gs[1], sharex=ax_sed)
ax_sed.loglog(wave_um[valid], flux_ujy[valid], "o", ms=8,
              color=COLORS.get("data", "C0"), alpha=0.7, label="Observed (SNR=15)")
ax_sed.fill_between(wave_um[mask], pred_lo_ujy[mask], pred_hi_ujy[mask],
                    color=COLORS.get("model", "C1"), alpha=0.3, label="68% credible")
ax_sed.plot(wave_um[mask], pred_med_ujy[mask], "-",
            color=COLORS.get("model", "C1"), lw=2.0, label="Posterior median")
ymed = np.median(flux_ujy[valid])
ax_sed.set_xlim(xlo, xhi)
ax_sed.set_ylim(ymed / 1e2, ymed * 1e2)
ax_sed.set_ylabel(r"$f_\nu$ [μJy]", fontsize=11)
ax_sed.legend(loc="upper left", frameon=False, fontsize=10)
ax_sed.grid(True, alpha=0.3, which="both")
ax_sed.set_title("Posterior-predictive SED: UV–MIR photometry", fontsize=12)
residual_sigma = (flux_ujy - pred_med_ujy) / noise_ujy
ax_res.axhline(0, color="k", ls="-", lw=1.2, alpha=0.5)
ax_res.axhline(2, color="k", ls="--", lw=0.8, alpha=0.3)
ax_res.axhline(-2, color="k", ls="--", lw=0.8, alpha=0.3)
ax_res.scatter(wave_um[valid], residual_sigma[valid], s=50,
               color=COLORS.get("data", "C0"), alpha=0.7)
ax_res.set_ylim(-3.5, 3.5)
ax_res.set_xlabel(r"Observed wavelength [μm]", fontsize=11)
ax_res.set_ylabel(r"Residual [σ]", fontsize=11)
ax_res.grid(True, alpha=0.3, which="major")
plt.savefig(os.path.join(FIGDIR, "05_posterior_predictive.png"), dpi=200, bbox_inches="tight")
plt.show()
✓ Saved 05_posterior_predictive.png
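The residual panel plots (observed − model) / σ per band. The same numbers can be collapsed into a χ² as a one-line sanity check. The sketch below uses toy arrays rather than the notebook's variables, and the helper names `sigma_residuals` and `chi2_summary` are hypothetical.

```python
import numpy as np

def sigma_residuals(flux_obs, flux_model, sigma):
    """Per-band residuals in units of the observational uncertainty."""
    flux_obs, flux_model, sigma = (
        np.asarray(a, dtype=float) for a in (flux_obs, flux_model, sigma)
    )
    return (flux_obs - flux_model) / sigma

def chi2_summary(residuals, n_free_params):
    """Return (chi2, reduced chi2); reduced is NaN if there are no degrees of freedom."""
    residuals = np.asarray(residuals, dtype=float)
    chi2 = float(np.sum(residuals**2))
    dof = residuals.size - n_free_params
    reduced = chi2 / dof if dof > 0 else float("nan")
    return chi2, reduced

# Toy example: three bands, one free parameter.
demo_res = sigma_residuals([1.0, 2.0, 3.0], [1.1, 1.8, 3.0], [0.1, 0.1, 0.2])
demo_chi2, demo_red = chi2_summary(demo_res, n_free_params=1)
print(np.round(demo_res, 2), round(demo_chi2, 2), round(demo_red, 2))
```

With 12 bands and 8 free parameters there are only 4 degrees of freedom, so the reduced χ² of this fit is a coarse indicator at best; the per-band residuals in the plot are more informative.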
Corner plot
[9]:
# Manual lightweight corner: pairwise hist2d + 1D histograms.
# ``result.plot_corner`` uses corner.py KDE which OOMs on macOS jetsam
# at 600 samples × 8 params on top of resident NUTS graph. Histograms
# are bounded peak RSS and visually equivalent for tutorial-grade plots.
# Filter to actually-varying params (std > 0); fixed chains add empty cells.
free = [
    k for k in samples_for_credible.keys()
    if float(np.std(np.asarray(samples_for_credible[k]))) > 1e-12
]
n_free = len(free)
fig, axes = plt.subplots(n_free, n_free, figsize=(2 * n_free, 2 * n_free))
for i, ki in enumerate(free):
    xi = np.asarray(samples_for_credible[ki])
    truth_i = float(truth[ki]) if ki in truth else None
    for j, kj in enumerate(free):
        ax = axes[i, j]
        if i == j:
            # Diagonal: 1D marginal with the true value marked.
            ax.hist(xi, bins=30, color=COLORS.get("model", "C1"), alpha=0.7, edgecolor="k", lw=0.3)
            if truth_i is not None:
                ax.axvline(truth_i, color=COLORS.get("truth", "C2"), ls="--", lw=1.5)
        elif j < i:
            # Lower triangle: 2D histogram of the pair (kj, ki).
            xj = np.asarray(samples_for_credible[kj])
            ax.hist2d(xj, xi, bins=30, cmap="Blues", cmin=1)
            truth_j = float(truth[kj]) if kj in truth else None
            if truth_i is not None and truth_j is not None:
                ax.plot(truth_j, truth_i, "*", ms=12, color=COLORS.get("truth", "C2"), mec="k", mew=0.5)
        else:
            # Upper triangle is unused.
            ax.set_visible(False)
        if i < n_free - 1:
            ax.set_xticklabels([])
        if j > 0:
            ax.set_yticklabels([])
        if i == n_free - 1:
            ax.set_xlabel(kj.replace("sfh_dpl_", "").replace("dust_", "d_").replace("met_", ""), fontsize=8)
        if j == 0:
            ax.set_ylabel(ki.replace("sfh_dpl_", "").replace("dust_", "d_").replace("met_", ""), fontsize=8)
        ax.tick_params(labelsize=7)
fig.suptitle(f"Parameter posterior: {n_free}-D NUTS ({len(xi)} samples)", fontsize=12, y=0.995)
fig.tight_layout()
plt.savefig(os.path.join(FIGDIR, "05_corner.png"), dpi=180, bbox_inches="tight")
plt.show()
print("Saved 05_corner.png", flush=True)
✓ Saved 05_corner.png
SFH posterior
[10]:
fig, ax = plt.subplots(figsize=(10, 5))
# Posterior SFH band: evaluate ``predict_sfh`` on a sub-sample of the chain.
# We avoid ``Posterior.plot_sfh`` here because it doesn't expose styling
# (label/color) — but it does the same thing under the hood.
n_sfh_draws = 100
n_avail_sfh = len(next(iter(samples_source.values())))
sfh_idxs = jax.random.permutation(jax.random.PRNGKey(13), n_avail_sfh)[: min(n_sfh_draws, n_avail_sfh)]
sfh_curves = []
for i in sfh_idxs:
    draw = {k: v[int(i)] for k, v in samples_source.items()}
    with contextlib.suppress(Exception):
        s = model.predict_sfh(draw)
        sfh_curves.append(np.asarray(s["sfr_full"]))
        t_gyr = np.asarray(s["t_gyr"])
if sfh_curves:
    sfh_arr = np.stack(sfh_curves)
    sfh_lo = np.percentile(sfh_arr, 16, axis=0)
    sfh_med = np.percentile(sfh_arr, 50, axis=0)
    sfh_hi = np.percentile(sfh_arr, 84, axis=0)
    ax.fill_between(t_gyr, sfh_lo, sfh_hi, alpha=0.3,
                    color=COLORS.get("model", "C1"), label="Posterior 68%")
    ax.plot(t_gyr, sfh_med, "-", lw=2.0,
            color=COLORS.get("model", "C1"), label="Posterior median")
# Truth curve on the same grid
sfh_truth = model.predict_sfh(truth)
ax.plot(np.asarray(sfh_truth["t_gyr"]), np.asarray(sfh_truth["sfr_full"]),
        "--", lw=2.0, color=COLORS.get("truth", "C2"), label="Truth", alpha=0.85)
ax.set_xscale("log")
ax.set_xlabel(r"Age [Gyr]", fontsize=11)
ax.set_ylabel(r"SFR [M$_\odot$/yr]", fontsize=11)
ax.set_title("Star formation history posterior", fontsize=12)
ax.legend(loc="upper right", frameon=False, fontsize=10)
ax.grid(True, alpha=0.3, which="both")
plt.savefig(os.path.join(FIGDIR, "05_sfh_posterior.png"), dpi=200, bbox_inches="tight")
plt.show()
✓ Saved 05_sfh_posterior.png
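To build intuition for what the four `sfh_dpl_*` parameters control, here is a sketch of a commonly used double power-law form, SFR(t) ∝ [(t/τ)^α + (t/τ)^(−β)]⁻¹, with α setting the falling slope and β the rising slope. This is an assumption for illustration: tengri's exact parameterisation, time axis, and normalisation are not shown in this notebook, and `dpl_sfr` is a hypothetical function.

```python
import numpy as np

def dpl_sfr(t_gyr, peak_sfr, tau_gyr, alpha, beta):
    """Double power-law SFH normalised so that its maximum equals peak_sfr.

    Assumed convention: t_gyr > 0 is time since the onset of star formation,
    beta is the rising slope, alpha the falling slope.
    """
    x = np.asarray(t_gyr, dtype=float) / tau_gyr
    shape = 1.0 / (x**alpha + x**(-beta))
    # Setting d(shape)/dx = 0 gives the location of the maximum analytically.
    x_peak = (beta / alpha) ** (1.0 / (alpha + beta))
    shape_peak = 1.0 / (x_peak**alpha + x_peak**(-beta))
    return peak_sfr * shape / shape_peak

# Same numbers as the mock truth: peak 15 Msun/yr, tau = 3 Gyr, alpha = 3.5, beta = 2.
t_grid = np.linspace(0.05, 13.0, 500)
sfr_curve = dpl_sfr(t_grid, peak_sfr=15.0, tau_gyr=3.0, alpha=3.5, beta=2.0)
print(f"peak at t ≈ {t_grid[np.argmax(sfr_curve)]:.2f} Gyr")
```

Under this form the turnover sits near τ but not exactly at it unless α = β, which is one reason τ, α, and β are correlated in the corner plot.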
Summary
[11]:
n_samples = len(next(iter(result.samples.values())))
rhat_max = max(float(v) for v in result.rhat().values())
n_div = result.diagnostics["n_divergent"]
print(
    f"NUTS on {phot_obs.n_filters} UV–IR bands, {spec.n_free} free parameters: "
    f"{n_samples} samples in {t_fit:.1f}s, R-hat_max = {rhat_max:.3f}, "
    f"{n_div} divergent transitions."
)
print(
    "Photometry alone leaves age, dust, and metallicity coupled. "
    "06_fitting_spectroscopy.py adds an optical spectrum to break the degeneracy."
)
tg.cite(result)
======================================================================
SUMMARY: Photometric SED Fitting
======================================================================
✓ Complete workflow:
Data: 12 UV–IR bands (SNR=15)
Model: 8 free params (SFH + dust + redshift + nebular)
Inference: NUTS 600 samples in 1523.1s
Diagnostics: R̂_max=1.6120, divergences=455
Derived: stellar mass, SFR(10/100 Myr), sSFR with 68% credible intervals
Validation: posterior-predictive residuals, SFH recovery, corner plots
Limitation: Photometry alone cannot break age–dust–metallicity degeneracy.
Solution: Add spectroscopy (notebook 06) to constrain stellar age.
Next: 06_fitting_spectroscopy.py for optical spectrum + line diagnostics
======================================================================
component name citation
───────── ────── ────────────────────────────────────
framework tengri Cooray et al. (2026, Paper I)
ssp DSPS Hearin et al. 2023 (MNRAS 521, 1741)
framework JAX Bradbury et al. 2018
[3 results — framework]
% ────────────────────────────────────────────────────────────────
% Citations for 3 components used by the model. Paste into your .bib file.
% ────────────────────────────────────────────────────────────────
% [framework] tengri
@article{Cooray_2026,
author = {{Cooray}, Suchetha},
title = {{tengri: Differentiable SED fitting with Information-Field-Theory star formation history priors. I. Framework and mock recovery}},
year = {2026},
journal = {in preparation},
}
% [ssp] DSPS
@article{Hearin_2023,
author = {{Hearin}, Andrew P. and {Chaves-Montero}, Jon{\'a}s and {Alarcon}, Alex and {Becker}, Matthew R. and {Benson}, Andrew},
title = {{DSPS: Differentiable stellar population synthesis}},
year = {2023},
journal = {\mnras},
doi = {10.1093/mnras/stad456},
archivePrefix = {arXiv},
eprint = {2112.06830},
}
% [framework] JAX
@article{Jamesbradbury_2018,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
year = {2018},
}
✓ Notebook complete: photometric SED fitting, NUTS inference, posterior validation