Inference engine for differentiable SED fitting with flexible method dispatch.
Separates inference strategy from the forward model by building a loss
function from the SEDModel’s predictions and the Parameters’ priors, then
running the chosen optimizer/sampler. Supports point estimation (MAP,
Laplace), gradient-free and gradient-based sampling (ESS, NUTS, Ray
Tracing, MCMC), variational inference (geoVI, MGVI), and nested sampling
(NSS) via a unified run(method) interface.
Parameters:
model (SEDModel) – Configured forward model with spec (Parameters), observation
(Photometry/Spectroscopy/etc.), and predictor methods.
data (array_like, shape (n_data,)) – Observed data (photometric fluxes or spectra). Units match the model’s
observation configuration. [erg/s/cm²/Hz] for photometry.
noise (array_like, shape (n_data,)) – 1-sigma measurement uncertainties. Same shape and units as data.
data_type (str or None) – Data type indicator: "photometry", "spectroscopy", or
"joint". If None (default), inferred from
model.observation. Explicit values override inference.
data_mask (array_like, bool or None) – Optional boolean mask for censored/non-detections. True = use datum
in likelihood, False = exclude. Default None (use all).
calibration_marginalize (bool, optional) – If True, analytically marginalize over spectroscopic calibration
polynomial coefficients (Chebyshev order 1–cal_n_poly) when
computing spectroscopic log-likelihood. Only applies when
data_type ∈ {"spectroscopy", "joint"}. Follows Prospector
(Johnson et al. 2021). Default False.
cal_n_poly (int, optional) – Number of Chebyshev polynomial coefficients for calibration
marginalization (order 1 through cal_n_poly). Default 3.
cal_prior_sigma (float, optional) – Standard deviation of Gaussian prior on each calibration coefficient.
Default 1.0.
eline_marginalize (bool or None, optional) – Whether to analytically marginalize emission line amplitudes.
None (default) auto-detects from the model’s Spectroscopy
config (checks eline_mode=="marginalized").
eline_prior_type (str or None, optional) – Prior type for emission line marginalization: "flat" (uniform) or
"cloudy" (grid-interpolated from Cloudy models).
None auto-detects from Spectroscopy.eline_prior_type.
Default None.
compile_modes (tuple[str, ...] or str or None, optional) –
Control background JIT compilation during __init__. Accepted values:
- None (default) → no background compile; first run() compiles
  lazily.
- explicit tuple[str, ...] (e.g., ("mcmc_nuts",)) → queue exactly
  those modes in the background thread.
- explicit str (e.g., "mcmc_nuts") → wrap into a 1-tuple
  ("mcmc_nuts",).
Compile modes are passed to compile(modes=...) and determine which
inference engines are pre-JIT-compiled before the first run() call.
See compile() docstring for valid mode names.
Returns:
Fitter instance with loss function compiled and ready for inference.
JIT-compatibility: Methods in this class are not JIT-compatible because
they perform Python-level branching on method names and manage resources
(thread compilation, caching). The returned loss function and sampler
engines are fully JIT-compiled and reusable across galaxies.
Background compilation: Background compilation is now opt-in via the
compile_modes parameter (default None = no background thread).
The first run() call will compile lazily. Set compile_modes="auto"
or compile_modes=("mcmc_nuts",) to spawn a daemon thread and pre-compile
specified inference modes before run() is called (typically <1s if warm,
or the full compile time on cold XLA). Set TENGRI_NO_BACKGROUND_COMPILE=1
in the environment to disable even when compile_modes is set (test
environments).
Engine caching: Compiled engines are cached on the Model object so that
multiple Fitters created with the same Model but different data reuse the
same XLA programs. Cache key depends on data_type, dimensionality, free
parameter names, and feature flags (emission lines, calibration).
Pre-compile the JIT inference engine ahead of time.
Triggers XLA compilation for all specified modes so that
subsequent fitter.run() calls have zero compilation delay.
Compiled programs are cached both in-memory (this session)
and on disk (/tmp/tengri_jax_cache, survives kernel restarts).
Parameters:
n_iterations (int) – Iteration count for the pre-compilation run. Changing
n_iterations at run time does NOT trigger recompilation
(the iteration count is a dynamic traced value).
n_samples (int) – Compile for this sample count. Changing n_samples
at run time DOES trigger recompilation (array shapes
depend on it).
n_posterior_samples (int) – Compile posterior draw for this many samples.
modes (tuple of str) – Which VI sample modes to pre-compile. Each mode compiles
separately. Default covers MGVI + geoVI update (fastest).
Add "nonlinear_resample" for full geoVI (~56s extra).
mcmc_methods (tuple of str) – MCMC methods to pre-compile. Supported values:
"nuts", "hmc", "dynamic_hmc", "ghmc".
Each call runs the full warmup + chain scan through JIT so
the XLA disk cache is populated before the first user call.
After fitter.compile(mcmc_methods=["nuts"]), a fresh
kernel restart deserializes in <1s instead of ~23s.
n_warmup (int) – Warmup steps used for the MCMC compilation run.
n_burnin (int) – Burn-in steps used for the MCMC compilation run.
n_mcmc_samples (int) – Sample steps used for the MCMC compilation run.
nss (bool) – Pre-compile the NSS (nested slice sampling) step and init
functions. NSS has a ~10–15s cold compile on the first
fitter.run("nss") call; setting nss=True moves that cost
to compile time. data_args is traced so the compiled program
is reused across galaxies with the same model configuration.
Compilation mechanics: Pre-compilation invokes jax.jit on
the forward model’s SED prediction and inference engines, storing
compiled XLA programs to disk. First fitter.run() will skip
XLA overhead by loading pre-compiled kernels. Typical times:
"linear_resample" + "nonlinear_update" ~3s; full modes ~60s;
NUTS ~23s (once per unique model shape).
MCMC cache key: The XLA program is keyed on logdensity_fn_2arg
identity, n_warmup, n_burnin, n_mcmc_samples, and
use_dense. Use the same values here as in fitter.run() to
guarantee a cache hit. Changing galaxy data does not invalidate
the cache (data_args is traced, not static).
>>> fitter = Fitter(model, data, noise)
>>> fitter.compile()  # ~3s for default VI modes
>>> fitter.compile(mcmc_methods=["nuts"])  # ~23s, then instant restarts
>>> fitter.compile(nss=True)  # ~12s, then instant restarts
>>> result = fitter.run("mcmc_nuts")  # instant after compile
>>> result = fitter.run("nss")  # instant after compile
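The difference between traced and static arguments described above can be illustrated with plain JAX. This is a minimal sketch, not tengri's implementation: toy_logdensity, run_chain, and trace_count are hypothetical names, and the counter only records how often JAX traces the function.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = {"n": 0}  # incremented only when JAX (re)traces run_chain


def toy_logdensity(theta, data):
    # Hypothetical stand-in for a log-density: Gaussian misfit to the data.
    return -0.5 * jnp.sum((theta - data) ** 2)


@partial(jax.jit, static_argnames=("n_steps",))
def run_chain(theta0, data, n_steps):
    trace_count["n"] += 1  # Python side effect: executes at trace time only
    grad_fn = jax.grad(toy_logdensity)

    def step(theta, _):
        theta = theta + 0.1 * grad_fn(theta, data)  # plain gradient ascent
        return theta, theta

    final, _ = jax.lax.scan(step, theta0, None, length=n_steps)
    return final


theta0 = jnp.zeros(3)
run_chain(theta0, jnp.ones(3), n_steps=10)        # first call: trace + compile
run_chain(theta0, 2.0 * jnp.ones(3), n_steps=10)  # new data, same shape: cache hit
traces_same_static = trace_count["n"]
run_chain(theta0, jnp.ones(3), n_steps=20)        # new static value: retrace
traces_new_static = trace_count["n"]
```

Here data plays the role of the traced data_args, and n_steps plays the role of a static setting such as n_warmup: changing the former reuses the compiled program, changing the latter does not.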
Return a hashable signature for cross-galaxy engine reuse.
Combines SEDModel’s compile_signature() with Fitter-specific
parameters that affect the compiled inference engine. Two Fitters
with matching signatures can share the same XLA-compiled engine,
even if they reference different SEDModel instances (as long as
those instances have the same compile_signature).
The signature does NOT include memory_mode, as it does not change
the generated HLO graph — it only affects posterior-chunking
behavior in the analysis layer (see _draw_jit_samples and
_draw_nonlinear_jit_samples). Toggling memory_mode between
“fast” and “low” reuses the same cached engine.
Returns:
Hashable immutable signature suitable for keying into
the module-level _SHARED_ENGINE_CACHE.
Used by _get_or_build_engine to enable cross-galaxy engine reuse
in PopulationFitter and CatalogFitter. The signature is computed
ONCE per Fitter construction and cached to avoid recomputation
in tight loops.
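The reuse pattern can be sketched with a frozen dataclass as the hashable key. Everything below is an illustrative assumption: ToySignature, its fields, _TOY_ENGINE_CACHE, and get_or_build_engine are hypothetical stand-ins, and the real signature contents are not documented here.

```python
from dataclasses import dataclass

_TOY_ENGINE_CACHE = {}  # hypothetical stand-in for a module-level shared cache


@dataclass(frozen=True)
class ToySignature:
    # Illustrative fields only; frozen dataclasses are hashable by value.
    data_type: str
    n_data: int
    free_params: tuple
    eline_marginalize: bool


def get_or_build_engine(signature, build):
    # Build once per distinct signature, then hand back the cached object.
    if signature not in _TOY_ENGINE_CACHE:
        _TOY_ENGINE_CACHE[signature] = build()
    return _TOY_ENGINE_CACHE[signature]


build_calls = []


def build():
    build_calls.append(1)
    return object()  # placeholder for a compiled engine


sig_a = ToySignature("photometry", 12, ("age_gyr", "stellar_mass"), False)
sig_b = ToySignature("photometry", 12, ("age_gyr", "stellar_mass"), False)
engine_a = get_or_build_engine(sig_a, build)
engine_b = get_or_build_engine(sig_b, build)  # equal signature: cache hit
```

Because equality is by value, two separately constructed signatures with the same fields map to the same cache entry, which is the property that allows reuse across different model instances.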
Creates a Fitter per galaxy, sharing the XLA compilation cache.
The first galaxy pays compile cost; subsequent galaxies load
from the persistent XLA cache (milliseconds each).
Works with any inference method — vi (default) gives
the best speed. Also usable for hierarchical individual fits.
Parameters:
batch (list of dict) – Each dict has "flux_obs" and "noise" arrays.
method (str) – Default "vi". Any method from run().
key (PRNGKey, optional) – Random seed for sampling methods. Default: jax.random.PRNGKey(42).
Parallelization strategy:
- For method="map" with precomputed photometry: uses jax.vmap
  to fit all galaxies in a single JIT call (1-2s total).
- For MCMC methods with fixed SFH: uses jax.vmap + shared adaptation.
- Otherwise: sequential Fitter per galaxy (load from XLA cache).
Compilation caching: All Fitters share the same Model instance,
enabling persistent XLA cache. After first galaxy, subsequent fits
are 10-100× faster depending on method.
Native VI tuning: When method contains "native" and
n_seeds is not explicitly passed, automatically sets n_seeds=5
for better convergence.
Dispatches to the underlying inference backend (variational, MCMC,
point estimation, or nested sampling) and returns a Posterior
object with samples, diagnostics, and derived quantities.
"raytrace", "nuts", "hmc", etc. (all MCMC methods)
"evidence" → "nss"
init_from (Posterior, optional) – Previous inference result to use as warm-start initialization.
The posterior mean is extracted and converted to unbounded space.
Useful for refining results across different methods. Default None.
key (PRNGKey, optional) – JAX random key. Default PRNGKey(42) for reproducibility.
Ignored for deterministic methods ("map", "laplace").
**kwargs –
Method-specific keyword arguments passed to the underlying backend:
VI methods: n_samples, n_kl_iter, tol_kl, sample_mode,
verbose, mirror_samples.
Raises:
ParameterError – If method is invalid or unrecognized.
RuntimeError – If background JIT compilation failed.
ValueError – If method-specific kwargs are invalid.
Notes
Method selection strategy:
Default ("vi"): geoVI is recommended for high-dimensional problems
(D>50) and population fitting. Captures non-Gaussian posterior geometry.
Exact posterior ("mcmc_nuts"): Use for D≤20 where exact sampling is
feasible and posterior validation is critical.
Fast large-D sampling ("mcmc_raytrace"): Use for D>50 with gradient
access; 250× more robust to noisy gradients than HMC.
Bayesian model comparison ("nss"): Estimates log-evidence for
comparing competing physical models (e.g., different dust laws).
Important gotchas:
VI posterior equivalence: "vi" (NIFTy geoVI) and "vi_native"
(pure JAX geoVI) target the same objective but are NOT posterior-equivalent.
The native version is ~19× faster but produces different posterior shapes
on some problems (e.g., PSD timescale can differ by order of magnitude).
Validate before swapping methods. See
bench/reports/2026-04-17_native_vs_nifty.md.
VIConfig.n_samples doubling: In geoVI, when mirror_samples=True
(default), n_samples=3 produces 6 effective samples (3 + 3 mirrors).
When tuning convergence, think in effective samples.
Ray Tracing step_size scaling: Ray Tracing uses step_size=0.05 by
default for D~137. There is a sharp viability cliff at ~0.06 where
acceptance drops from 80% to 0%. Use smaller step sizes for safety.
Method defaults from file: Default hyperparameters (n_kl_iter,
n_warmup, etc.) are loaded from defaults.toml if available.
Command-line kwargs override file defaults.
MAP provides a quick point estimate; MCMC and VI refine from this
initialization, converging faster than from random initialization.
Reproducibility:
Pass key=jax.random.PRNGKey(seed) to control randomness across runs.
key=None defaults to PRNGKey(42) for reproducibility.
Compile-cache behaviour (smart lean, 2026-05):
run accepts a lean kwarg (default inferred from
tengri.lean() / tengri.persistent() context). With
lean=True (the default), the inference-body cache is
cleared of stale entries — every entry whose
(compile_signature,method) differs from the current call
is dropped, but the entry that matches the current call (if it
exists from a prior identical run) is kept. Forward-model,
log-density, loss, and gradient compiles survive unconditionally.
Implications:
Multi-phase notebooks (MAP → HMC → posterior-predictive)
peak at one inference scan body in RAM, not several.
Catalog loops calling fitter.run(method) repeatedly with
the same model and method pay one compile, not N — without
needing tengri.persistent().
tengri.gc() drops everything including structural caches;
use it between loops that build many different model
configurations.
Examples
Example 1: Quick exploration with MAP + geoVI
>>> fitter = Fitter(model, data, noise)
>>> result = fitter.run("vi")  # geoVI with defaults
>>> print(result.summary())
The summary includes:
- Data dimensionality and median signal-to-noise ratio
- Free parameters and latent grid points (ξ) if stochastic SFH
- Parameter names, prior distributions, and bounds
- All available inference methods
Examples
>>> fitter = Fitter(model, data, noise)
>>> print(fitter.summary())
Fitter data_type: photometry
──────────────────────────────────────────────────────────────
 Data points: 100
 Median S/N: 5.2
 Parameters: 8 free + 64 latent (ξ)
...
Inference results with samples, diagnostics, and derived quantities.
Stores posterior samples (or point estimate for MAP), best-fit parameters,
convergence diagnostics, and provides methods for summary statistics,
derived physical quantities, ArviZ conversion, and refinement via resampling
or additional fitting iterations.
Parameters:
samples (dict or None) – Posterior samples in physical parameter space (optional, set by inference).
params (dict) – Best-fit or posterior mean parameters.
method (str) – Inference method name (e.g., "vi", "mcmc_nuts", "map").
wall_time_s (float) – Total wall-clock runtime in seconds.
Posterior samples in physical parameter space. Each value has shape
(n_samples, …). Keys are parameter names (e.g., "stellar_mass",
"age_gyr", "psd_xi"). None for point estimates (MAP, Laplace,
Pathfinder).
Emission line fluxes. Shape (n_lines,) for MAP, (n_samples, n_lines) for
sampling. None if no emission line fitting/marginalization was enabled.
Flux units match input data [erg/s/cm²].
Derived quantities: The derived property computes stellar mass, SFR,
sSFR, etc. by re-running the forward model on all samples. For MAP results,
returns scalars; for MCMC/VI results, returns arrays (one per sample).
Emission line diagnostics: Methods line_fluxes(), bpt_nii(),
and balmer_decrement() provide astrophysical diagnostics on emission
lines. Require eline_mode!="none" in Spectroscopy config.
Convergence diagnostics: Use check_convergence(), autocorrelation(),
and effective_sample_size() to assess MCMC chain quality.
Resampling and refinement: Use resample() to draw new samples from
the posterior, and refine() to improve results by running additional
inference iterations (requires _fitter).
Uses FFT for O(N log N) efficiency instead of naive O(N * max_lag).
Normalized so that ACF[0] = 1 and ACF[k] ∈ [-1, 1].
Parameters with zero variance (fixed parameters) return zero ACF.
Examples
>>> acf = result.autocorrelation()
>>> lag_cutoff = np.argmax(acf["stellar_mass"] < 0.05)
>>> print(f"ACF drops below 0.05 at lag {lag_cutoff}")
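A minimal NumPy sketch of an FFT-based autocorrelation with the normalization and zero-variance convention described above. The function name acf_fft is hypothetical, and this is not necessarily how the library computes it.

```python
import numpy as np


def acf_fft(chain):
    """Normalized autocorrelation of a 1-D chain via a zero-padded FFT."""
    x = np.asarray(chain, dtype=float)
    n = x.size
    centered = x - x.mean()
    if np.dot(centered, centered) == 0.0:
        return np.zeros(n)  # zero-variance (fixed) parameter: ACF set to zero
    n_fft = 2 * n  # zero-padding avoids circular wrap-around
    spectrum = np.fft.rfft(centered, n=n_fft)
    acov = np.fft.irfft(spectrum * np.conj(spectrum), n=n_fft)[:n]
    return acov / acov[0]  # ACF[0] = 1 by construction


rng = np.random.default_rng(0)
chain = rng.normal(size=2000)
acf = acf_fft(chain)

# Direct O(N) evaluation of the lag-1 term for comparison.
c = chain - chain.mean()
acf_lag1_direct = np.dot(c[:-1], c[1:]) / np.dot(c, c)
```

The FFT route computes all lags at once in O(N log N); the direct sum above costs O(N) per lag, which is where the O(N * max_lag) figure comes from.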
Two autocorrelation time estimates are computed:
- tau_standard: based on standard ACF
- tau_absolute: based on absolute-deviation ACF (robust to mean/variance changes)
The maximum is returned (conservative). ESS = N / tau_max.
Convergence flag uses the criterion N > 5τ_max from Behroozi (2025).
Visual extinction A(V) from the Balmer decrement (Calzetti+2000).
Converts the observed Hα/Hβ ratio to a V-band attenuation
assuming Case B recombination (intrinsic ratio 2.86) and the
Calzetti et al. (2000) starburst attenuation law
(\(R_V = 4.05\), \(k(\mathrm{H}\alpha) = 2.53\), \(k(\mathrm{H}\beta) = 3.61\)).
Returns:
(median, lo_68, hi_68) of A(V) in [mag]. For MAP results
all three values are equal. Negative values are returned
as-is (unphysical, but informative for noisy data — clip
externally if desired).
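With the constants quoted above, the standard conversion is \(E(B-V) = \frac{2.5}{k(\mathrm{H}\beta) - k(\mathrm{H}\alpha)} \log_{10}\!\left[\frac{(\mathrm{H}\alpha/\mathrm{H}\beta)_\mathrm{obs}}{2.86}\right]\) and \(A(V) = R_V \, E(B-V)\). A minimal sketch of that arithmetic follows; av_from_balmer is a hypothetical helper name, not the library's method.

```python
import numpy as np

R_V = 4.05              # Calzetti et al. (2000)
K_HALPHA = 2.53         # k(H-alpha)
K_HBETA = 3.61          # k(H-beta)
INTRINSIC_RATIO = 2.86  # Case B H-alpha/H-beta


def av_from_balmer(ratio_obs):
    """A(V) in mag from an observed H-alpha/H-beta flux ratio."""
    ratio_obs = np.asarray(ratio_obs, dtype=float)
    ebv = 2.5 / (K_HBETA - K_HALPHA) * np.log10(ratio_obs / INTRINSIC_RATIO)
    return R_V * ebv


ratios = np.array([2.5, 2.86, 4.0])
av = av_from_balmer(ratios)
```

A ratio below 2.86 yields a negative A(V), which matches the note above that negative values are passed through rather than clipped.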
Returns the posterior distribution of the Balmer decrement
(Hα/Hβ), which is a direct dust attenuation diagnostic.
The intrinsic Case B ratio is 2.86; higher values indicate dust.
Returns:
(median, lo_68, hi_68) of Hα/Hβ (dimensionless ratio).
For MAP results, all three values are equal.
Raises:
ValueError – If Hα or Hβ fluxes are not available.
Notes
The Balmer decrement (Hα/Hβ flux ratio) is insensitive to stellar
population age and metallicity; deviations from the intrinsic Case B
value of 2.86 (Osterbrock 1989) directly indicate dust attenuation.
A Balmer decrement > 3.0 typically indicates significant dust.
Examples
med, lo, hi = result.balmer_decrement()
print(f"Ha/Hb = {med:.2f} [{lo:.2f}, {hi:.2f}]")
# Intrinsic Case B = 2.86; excess indicates dust attenuation
BPT-NII ([NII]/Hα vs [OIII]/Hβ) diagram coordinates.
Returns log10 line ratios for each posterior sample.
Returns:
log_nii_ha (ndarray, shape (n_samples,) or scalar) – log10([NII]6584 / Hα). For MAP, returns scalar.
log_oiii_hb (ndarray, shape (n_samples,) or scalar) – log10([OIII]5007 / Hβ). For MAP, returns scalar.
Raises:
ValueError – If emission line fluxes are not available or BPT lines are missing.
Notes
The BPT diagram is a standard AGN/SF diagnostic that uses the ratios
[NII]/Hα (x-axis) and [OIII]/Hβ (y-axis). Non-detections (negative
or zero fluxes) are returned as NaN and will not be plotted.
Diagnostic lines follow Kewley et al. (2001, 2006) conventions.
Examples
x, y = result.bpt_nii()
plt.scatter(x, y, alpha=0.3)
# Overlay diagnostic lines from starburst/Seyfert boundaries
Convergence criterion: N > 5τ_max for all parameters.
If N < 5τ for any parameter, the chain is flagged as unconverged.
See Behroozi (2025) for justification of the 5τ threshold.
Examples
>>> conv = result.check_convergence()
>>> if conv["all_converged"]:
...     print("Chain converged!")
... else:
...     print(f"Unconverged: {conv['warnings']}")
...     print("Run additional samples and use refine()")
For MAP: computed on the single best-fit → dict of scalars.
For NUTS/geoVI: computed on all samples → dict of arrays.
Returns:
Keys: derived quantity names ("stellar_mass", "sfr_100myr", etc.).
For MAP: values are scalars.
For sampling: values are arrays of shape (n_samples,).
Units match the forward model convention (stellar mass in [Msun],
SFR in [Msun/yr]).
This is a cached property — computed on first access and cached thereafter.
Requires _model to be set (populated automatically by Fitter.run()).
For stochastic SFH, unresolved bursts are included in sfr_10myr and
sfr_100myr outputs.
For MAP results, returns a simple string indicating no samples are available.
For sampling methods, tabulates ESS for each parameter alongside median
and 68% credible intervals. Use rhat() for the split-chain
Gelman-Rubin diagnostic.
Estimate effective sample size (ESS) for each parameter.
Uses Sokal’s self-consistent window method (Behroozi 2025):
τ = 1 + 2 Σ ρ(k), truncated at k > 5τ. Takes the max of
standard and absolute-deviation autocorrelation times for
a conservative estimate.
Returns:
Keys: parameter names. Values: ESS (float).
ESS = N / τ, where N is the total number of samples
and τ is the integrated autocorrelation time.
ESS measures the number of independent samples. Low ESS
(< 100 for typical analyses) indicates poor mixing.
The threshold N > 5τ (equivalently ESS = N/τ > 5) indicates
adequate sampling for most purposes.
Examples
>>> ess = result.effective_sample_size()
>>> print(f"Stellar mass ESS: {ess['stellar_mass']:.0f}")
>>> if ess["stellar_mass"] < 100:
...     print("Warning: low ESS, may need more samples")
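The windowed estimate τ = 1 + 2 Σ ρ(k) can be sketched as follows. This covers only the standard ACF (not the absolute-deviation variant), assumes the window is the smallest lag M satisfying M ≥ 5 τ(M), and uses hypothetical function names.

```python
import numpy as np


def integrated_autocorr_time(chain, c=5.0):
    """Integrated autocorrelation time with a self-consistent window."""
    x = np.asarray(chain, dtype=float)
    n = x.size
    centered = x - x.mean()
    spectrum = np.fft.rfft(centered, n=2 * n)
    acov = np.fft.irfft(spectrum * np.conj(spectrum), n=2 * n)[:n]
    if acov[0] <= 0.0:
        return float("nan")  # zero-variance chain: tau is undefined
    rho = acov / acov[0]
    lags = np.arange(1, n)
    taus = 1.0 + 2.0 * np.cumsum(rho[1:])  # tau(M) for M = 1 .. n-1
    converged = lags >= c * taus           # window condition M >= c * tau(M)
    m = int(np.argmax(converged)) if converged.any() else n - 2
    return float(max(taus[m], 1.0))


def toy_effective_sample_size(chain):
    return len(chain) / integrated_autocorr_time(chain)


rng = np.random.default_rng(0)
white = rng.normal(size=50_000)

# AR(1) chain with phi = 0.9; its analytic tau is (1 + phi) / (1 - phi) = 19.
ar1 = np.empty_like(white)
ar1[0] = white[0]
for i in range(1, white.size):
    ar1[i] = 0.9 * ar1[i - 1] + white[i]

tau_white = integrated_autocorr_time(white)
tau_ar1 = integrated_autocorr_time(ar1)
ess_ar1 = toy_effective_sample_size(ar1)
```

For the correlated chain the ESS is roughly N/19, far fewer independent draws than the raw sample count suggests.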
Rest-frame emission line equivalent widths from posterior samples.
For each emission line, predicts the rest-frame SED for each posterior
sample (or the MAP point estimate) using the attached forward model
and integrates the line flux relative to the local continuum estimated
from sidebands flanking the line. Sign convention follows
tengri.analysis.diagnostics.spectral.equivalent_width():
positive EW for emission, negative for absorption.
Parameters:
window_aa (float, optional) – Half-width of the line integration window [Angstrom]. Default 20.
continuum_width_aa (float, optional) – Width of each sideband used to estimate the continuum [Angstrom].
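Sidebands sit at [lambda_0 ± window_aa ± continuum_width_aa], i.e. each sideband starts at the edge of the integration window and extends continuum_width_aa further from the line center.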
Default 50.
Returns:
{line_name: (median_EW, lo_68, hi_68)} in [Angstrom]. For MAP
results, all three values coincide.
Raises:
ValueError – If eline_fluxes / eline_wavelengths are unavailable, or if
no _model is attached to compute the continuum.
Notes
The continuum is estimated locally per line from the model-predicted
rest-frame SED, not from a precomputed continuum-only grid. The
sideband choice therefore must avoid contamination from neighbouring
emission lines; defaults are tuned for optical BPT lines.
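A sideband-based equivalent width can be sketched as below. This is an illustration under stated assumptions, not the library routine: the function name is hypothetical, the continuum is taken as the median of both sidebands, and each sideband is assumed to span from window_aa to window_aa + continuum_width_aa away from the line center.

```python
import numpy as np


def toy_equivalent_width(wave, flux, line_center,
                         window_aa=20.0, continuum_width_aa=50.0):
    """Emission-positive equivalent width [Angstrom] from local sidebands."""
    wave = np.asarray(wave, dtype=float)
    flux = np.asarray(flux, dtype=float)
    offset = np.abs(wave - line_center)
    in_line = offset <= window_aa
    in_sideband = (offset > window_aa) & (offset <= window_aa + continuum_width_aa)
    continuum = np.median(flux[in_sideband])  # assumed continuum estimator
    excess = flux[in_line] / continuum - 1.0  # > 0 for emission
    w = wave[in_line]
    # Trapezoidal integration of the normalized excess over the line window.
    return float(np.sum(0.5 * (excess[1:] + excess[:-1]) * np.diff(w)))


# Toy spectrum: flat continuum plus a Gaussian emission line.
wave = np.arange(6400.0, 6700.0, 0.5)
continuum_level = 2.0
line_flux = 10.0
sigma = 2.0
profile = np.exp(-0.5 * ((wave - 6563.0) / sigma) ** 2) / (sigma * np.sqrt(2 * np.pi))
flux = continuum_level + line_flux * profile
ew = toy_equivalent_width(wave, flux, line_center=6563.0)
```

For this toy spectrum the equivalent width is the line flux divided by the continuum level, 10 / 2 = 5 Å, which the sketch recovers because the sidebands are free of line emission.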
Returns median and 68% credible interval for each emission line.
Returns:
{line_name: (median, lo_68, hi_68)} for each emission line.
Flux units match the input data [erg/s/cm²]. For MAP results, all three
values are the same (single-point estimate).
Raises:
ValueError – If no emission line fluxes are available. Set
eline_mode="marginalized" or "fitted" in Spectroscopy to enable.
Notes
Each line’s credible interval is computed as the 16th, 50th, and 84th
percentiles of the posterior samples. For single samples (MAP), all three
values coincide.
Reads HDF5 file saved by save(). The _fitter back-reference
is not restored (it is runtime-only). To use refine() or validate(),
set the _fitter attribute manually or reload using model.fit(...).
Provide model to enable derived quantity computation.
Examples
>>> result = Posterior.load("posterior_result.h5", model=model)
>>> print(result.summary_table())
>>> derived = result.derived  # If model is provided
Plot corner (triangle) plot of posterior distributions.
Parameters:
params (list of str, optional) – Parameter names to include. Defaults to all scalar physical params.
Automatically excludes psd_xi (latent field) and constant parameters.
truths (dict, optional) – True values to mark with dashed lines. Keys should match parameter names.
color (str) – Color for this posterior’s contours and histograms.
fig, axes (matplotlib Figure and ndarray of Axes, optional) – If provided, overlay on an existing corner plot (for comparing posteriors).
label (str, optional) – Legend label for this posterior (appears in legend on diagonal).
Returns:
fig – The corner plot figure.
Return type:
matplotlib Figure
Notes
Creates an N×N triangle plot (lower triangle only). Diagonal shows
1D marginal distributions with KDE overlay and quantile lines.
Off-diagonal shows 2D histograms with credible region contours at
68% and 95%. Includes derived quantities (stellar_mass, sfr_100myr,
sfr_10myr) if _model is available. Derived quantities are
plotted in log10 space.
Plot posterior predictive SED with credible interval.
Draws n_draws parameter samples from the posterior, computes
the rest-frame SED for each, and shades the 16th–84th percentile
band around the median.
Parameters:
n_draws (int) – Number of posterior draws to use for the band. Ignored for
MAP results (plots single SED).
wave_range ((float, float)) – Wavelength range in [Angstrom] to display.
ax (matplotlib Axes, optional) – Axes to plot on. Creates new figure if None.
Returns:
fig – The SED plot figure.
Return type:
matplotlib Figure
Notes
Plots λ F_λ (rest-frame spectral energy density) normalized at 5500 Å.
For MAP results, shows the single best-fit SED.
For sampling methods, draws n_draws random samples and computes
percentiles of the SED over those draws.
Requires _model to be available (set by model.fit()).
Uses log-log axes for visibility across wavelength range.
n_draws (int) – Number of posterior draws. Ignored for MAP (plots single SFH).
ax (matplotlib Axes, optional) – Axes to plot on. Creates new figure if None.
Returns:
fig – The SFH plot figure.
Return type:
matplotlib Figure
Notes
Plots both the stochastic SFH (burst component) and smooth component.
For MAP: shows single best-fit SFH (both stochastic and smooth).
For sampling: draws n_draws random samples and computes 16th–84th
percentile bands.
X-axis: lookback time [Gyr].
Y-axis: SFR [Msun/yr].
Requires _model to be available (set by model.fit()).
Posterior predictive predictions, residuals, and chi^2 distribution.
Pushes posterior draws (or the MAP point estimate) through the
attached forward model’s predict_photometry and reports
per-draw predictions, standardised residuals, and a chi^2
distribution against the supplied data + noise.
Parameters:
data (array_like, shape (n_obs,)) – Observed data the posterior was conditioned on. Same units
as predict_photometry’s output.
n_samples (int, optional) – How many posterior draws to evaluate. None (default)
uses every available draw; for MAP results this is
implicitly 1. For sampling results, draws are selected
via resample() (with replacement) using key.
key (PRNGKey, optional) – JAX PRNG key for resampling. If None, defaults to
jax.random.PRNGKey(0).
Returns:
predictions: shape (N, n_obs),
residuals: (data - prediction) / noise, of shape
(N, n_obs),
chi2: per-draw \(\chi^2 = \sum_i r_i^2\), shape
(N,),
chi2_median, chi2_lo, chi2_hi: 50th, 16th, and 84th
percentiles of the chi^2 distribution (scalars).
N is 1 for MAP, otherwise n_samples (or the full
chain length if n_samples is None).
This is a deterministic posterior predictive (no extra noise
realisation per draw). For replicated PPCs that draw observation
noise per sample, layer noise*jax.random.normal on top of
predictions.
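A replicated check along those lines might look like the sketch below. The arrays are toy stand-ins for the predictions, noise, and data described above, replicate_with_noise is a hypothetical helper, and the chi^2 comparison is one common choice of discrepancy rather than something this API returns.

```python
import jax
import jax.numpy as jnp


def replicate_with_noise(predictions, noise, key):
    """Add one Gaussian noise realisation per posterior draw."""
    predictions = jnp.asarray(predictions)
    eps = jax.random.normal(key, predictions.shape)
    return predictions + jnp.asarray(noise) * eps


# Toy stand-ins: 4 posterior draws, 3 observed data points.
predictions = jnp.ones((4, 3))
noise = jnp.array([0.1, 0.2, 0.3])
data = jnp.array([1.05, 0.9, 1.2])

replicated = replicate_with_noise(predictions, noise, jax.random.PRNGKey(0))

# Discrepancy of observed vs replicated data, each against the same draw.
chi2_obs = jnp.sum(((data - predictions) / noise) ** 2, axis=-1)
chi2_rep = jnp.sum(((replicated - predictions) / noise) ** 2, axis=-1)
ppc_p_value = float(jnp.mean(chi2_rep >= chi2_obs))
```

A p-value near 0 or 1 would suggest the observed chi^2 is atypical of data simulated from the fitted model; with only four toy draws the number here is purely illustrative.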
Raises:
RuntimeError – If ._fitter is not set (Posterior created outside model.fit/fitter.run).
Notes
Warm-starts the new inference from this posterior’s parameters or samples.
Common use cases: VI → MCMC refinement (exact inference on top of variational
fit), MCMC → different sampler (e.g. raytrace → nuts), or quick method
→ expensive method for publication.
For MAP results, returns the point estimate (repeated n times if n > 1).
For sampling results, draws n indices uniformly from [0, n_samples) with
replacement. Use for Monte Carlo propagation of posterior uncertainty
through forward models.
Examples
>>> key = jax.random.PRNGKey(0)
>>> sample = result.resample(key, n=1)  # Single draw
>>> samples = result.resample(key, n=100)  # 100 resamples
>>> sfhs = [model.predict_sfh(samples) for ...]  # Propagate
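Uniform resampling with replacement over a dict of sample arrays can be sketched as follows. The helper name and toy arrays are assumptions for illustration, not the library's code.

```python
import jax
import jax.numpy as jnp


def resample_with_replacement(samples, key, n=1):
    """Draw n joint posterior samples (with replacement) from a dict of arrays."""
    n_samples = next(iter(samples.values())).shape[0]
    idx = jax.random.randint(key, (n,), 0, n_samples)  # uniform on [0, n_samples)
    # The same indices are applied to every parameter so joint draws stay paired.
    return {name: values[idx] for name, values in samples.items()}


toy_samples = {
    "stellar_mass": jnp.linspace(9.0, 11.0, 50),
    "psd_xi": jnp.arange(50 * 4, dtype=jnp.float32).reshape(50, 4),
}
draws = resample_with_replacement(toy_samples, jax.random.PRNGKey(0), n=100)
```

Sharing one index vector across parameters keeps each draw a consistent point of the joint posterior, which matters when the draws are pushed through a forward model.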
Splits each parameter chain in half and computes the classical
\(\hat R = \sqrt{\hat V / W}\) against the two halves.
\(\hat R \approx 1.0\) indicates convergence;
\(\hat R > 1.01\) (Vehtari+2021) suggests failure to mix.
Parameters:
exclude_prefixes (tuple of str, optional) – Parameter name prefixes to skip. Default skips psd_xi
(GP latent vector — high-D, not informative per-component).
Returns:
Parameter name → \(\hat R\). Static (zero-variance) and
excluded parameters are dropped.
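The classical split statistic for a single chain can be sketched in NumPy as below, with W the mean within-half variance, B the between-half variance, and \(\hat V = \frac{n-1}{n} W + \frac{B}{n}\). The function name toy_split_rhat is hypothetical.

```python
import numpy as np


def toy_split_rhat(chain):
    """Classical R-hat computed from the two halves of one chain."""
    x = np.asarray(chain, dtype=float)
    n = x.size // 2
    halves = np.stack([x[:n], x[n:2 * n]])         # shape (2, n)
    within = halves.var(axis=1, ddof=1).mean()     # W
    between = n * halves.mean(axis=1).var(ddof=1)  # B
    var_hat = (n - 1) / n * within + between / n   # V-hat
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
stationary = rng.normal(size=4000)
drifting = stationary.copy()
drifting[2000:] += 3.0  # second half shifted: the chain has not mixed

rhat_stationary = toy_split_rhat(stationary)
rhat_drifting = toy_split_rhat(drifting)
```

The stationary chain gives a value very close to 1, while the shifted chain gives a value well above the 1.01 threshold because the between-half variance dominates.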
Serializes samples, params, diagnostics, loss history, and
emission line data. Model and fitter references are NOT saved
(they are non-serializable runtime objects).
Saves to HDF5 format with groups:
- samples: posterior samples (if available)
- params: best-fit or MAP parameters
- loss_history: optimization loss over iterations (if available)
- diagnostics: method-specific convergence metrics
- eline: emission line fluxes, covariances, names, wavelengths
Use load() to restore the Posterior from disk.
Decomposes the model prediction into stellar (attenuated /
intrinsic) + nebular + shock + dust IR + AGN + radio + X-ray
for every posterior sample (or the MAP point estimate) by
running the orchestrator pipeline per draw and reading the
per-component SED arrays each adapter publishes into
state.derived.
Parameters:
wavelength (array_like, optional) – Ignored — kept for backwards compatibility. The orchestrator
uses the model’s SSP wavelength grid; pass-through to a
different grid is no longer supported here. Interpolate the
returned arrays externally if a different grid is needed.
Returns:
Keys: wavelength plus the entries of
Posterior._COMPONENT_KEYS (sed_total,
sed_attenuated, sed_intrinsic, sed_nebular,
sed_shock, sed_dust_ir, sed_agn, sed_radio,
sed_xray). Each component array has shape
(n_wave,) for MAP and (n_samples,n_wave) for
sampling.
Reads the orchestrator’s per-component publications:
sed_dust_attenuated (stellar post-attenuation),
sed_dust_ir, sed_nebular, sed_shock, sed_agn,
sed_radio, sed_xray — each adapter publishes its own
contribution into state.derived. sed_total is the
accumulated state.sed_intrinsic after the chain runs;
sed_intrinsic (stellar pre-attenuation) is reconstructed
from the published lnu_age cube via sum(lnu_age,axis=0).
For Posterior draws this is still a Python loop (one
orchestrator pass per sample); JIT vectorisation across draws
is a separate optimisation.
For MCMC and VI results, credible intervals are 16th and 84th percentiles.
For MAP results, returns point estimates without intervals.
Does not include high-dimensional latent fields (psd_xi).
Return the summary as a string (parity with Parameters.summary_str).
Same content as summary_table() — alias for naming consistency
across the discovery API (tengri.Parameters.summary_str returns
the same kind of string).
The table includes:
- Method name and number of samples (or "MAP")
- Wall-clock time in seconds
- Parameter names with median and credible intervals
- Effective sample size (ESS) for sampling methods
- Method-specific diagnostics (accept rate, divergences, loss, etc.)
- Log evidence (if available from nested sampling)
ArviZ InferenceData object with posterior group containing
all scalar parameters.
Return type:
az.InferenceData
Notes
Requires arviz to be installed: pip install arviz.
High-dimensional latent fields (psd_xi) are excluded.
Samples are reshaped to (1, n_samples) format (1 chain).
Use ArviZ tools for advanced visualization and diagnostics
(forest plots, rank plots, etc.).
For MAP results, all parameters become Fixed at the MAP value.
For sampling methods, each parameter gets a Gaussian prior with:
- mean: median of samples
- sigma: standard deviation of samples
- bounds: [min, max] from samples (clipping)
Inherits stochastic and n_grid settings from the original model.
Examples
>>> posterior_params = result.to_param_spec()
>>> # Use as starting point for next fit
>>> refined = fitter.run("mcmc_nuts", init_from=posterior_params)
Run a short MCMC check and return a validation summary.
Runs n_steps of Ray Tracing (or NUTS for D≤20) from this
posterior’s MAP estimate, then computes the marginal overlap
between this posterior and the MCMC check posterior for each
parameter.
Parameters:
n_steps (int) – Number of MCMC steps. Default 200 (quick sanity check).
**kwargs – Forwarded to the MCMC run.
Returns:
Keys: "mcmc_result" (Posterior from MCMC check),
"overlap" (dict of float per parameter, 1.0 = perfect overlap,
0.0 = no overlap), "passed" (bool, True when all overlaps > 0.5).
Validation checks whether a quick MCMC run agrees with the current
posterior (typically from VI or MAP). High overlap (> 0.5) indicates
the method is reliable; low overlap suggests the posterior may be
biased or misspecified.
Overlap is computed as the histogram intersection at each parameter.
For variational methods (VI, geoVI), validates the approximate posterior.
For MCMC methods, serves as a sanity check for chain convergence.
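A histogram-intersection overlap for one parameter can be sketched as follows. The bin count, the shared bin edges, and the function name are assumptions for illustration; the library's binning choices are not documented here.

```python
import numpy as np


def histogram_overlap(a, b, bins=50):
    """Histogram intersection of two 1-D sample sets (1 = identical, 0 = disjoint)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    lo = min(a.min(), b.min())
    hi = max(a.max(), b.max())
    edges = np.linspace(lo, hi, bins + 1)  # shared edges for both histograms
    pa, _ = np.histogram(a, bins=edges)
    pb, _ = np.histogram(b, bins=edges)
    pa = pa / pa.sum()                     # normalize counts to probabilities
    pb = pb / pb.sum()
    return float(np.minimum(pa, pb).sum())


rng = np.random.default_rng(0)
vi_like = rng.normal(0.0, 1.0, size=5000)
mcmc_like = rng.normal(0.2, 1.0, size=5000)  # slightly shifted check posterior
overlap_same = histogram_overlap(vi_like, vi_like)
overlap_close = histogram_overlap(vi_like, mcmc_like)
overlap_disjoint = histogram_overlap(rng.uniform(0, 1, 1000), rng.uniform(5, 6, 1000))
passed = overlap_close > 0.5
```

With short check chains the histogram estimate is noisy, so an overlap just below 0.5 from a 200-step run is a prompt to run longer rather than proof of a biased posterior.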
Manages population-level inference via hierarchical VI, learning the shared
PSD hyperparameters (σ_PSD, τ_PSD) across a population of galaxies while
preserving per-galaxy latent fields and physical parameters.
Parameters:
model_factory (callable) – Function(psd_sigma, psd_tau_myr) → Model.
Creates a model with the given PSD params. All other params
(SFH, dust, etc.) come from the model’s Parameters.
galaxies (list of dict) – Each dict has "flux_obs", "noise", and optionally "spec_obs",
"spec_noise", "wave_spec".
psd_sigma_prior (tuple) – (lo, hi) for uniform prior on σ_PSD.
psd_tau_prior (tuple) – (lo, hi) for uniform prior on τ_PSD (Myr).
Wraps all hierarchical inference methods (EVI, VI, MCMC) with automatic
initialization via per-galaxy MAP estimation. The class builds a single
flat parameter vector [φ_shared, ξ_1, θ_1, …, ξ_N, θ_N] and optimizes
it via the specified method.
Unlike Fitter.run(), PopulationFitter.run() does not support
warm-start initialization via init_from because hierarchical
inference methods (EVI, CFM-based geoVI, Ray Tracing) initialize
per-galaxy parameters via MAP estimation automatically. The init_from
parameter is not meaningful in the hierarchical context.
Automatic initialization: all methods initialize per-galaxy parameters
via MAP estimation before starting the hierarchical inference. First call
may be slow due to JIT compilation. Subsequent calls are fast.
Approximate runtime: ~30 seconds for 10 galaxies on CPU (method-dependent).
Unlike CatalogFitter,
PopulationFitter does not expose an n_pad argument:
the hierarchical population field couples all N galaxies, so
padding with dummy galaxies would contribute spurious prior
mass to the population hyperparameters (e.g. psd_sigma,
psd_tau_myr) even with masked likelihoods. To amortize XLA
compile cost across notebook restarts, slurm tasks, or sweeps
over different catalog sizes, rely on the persistent
compilation cache instead — see
tengri.enable_persistent_cache() and
docs/inference/compilation_cache.md.
This class holds the posterior distribution over shared PSD hyperparameters
(σ_PSD, τ_PSD) inferred across a population of galaxies, along with optional
per-galaxy individual posteriors.
Parameters:
shared_samples (dict) – Posterior samples for shared PSD params. Keys are param names (e.g.,
"psd_sigma", "psd_tau_myr"), values are arrays of shape (n_samples,).
shared_params (dict) – Posterior mean of shared PSD params (computed from shared_samples).
individual_samples (list of dict, optional) – Per-galaxy posterior samples. Each element is a dict with per-galaxy
parameter names as keys. If None, individual posteriors are not stored
(for memory efficiency).
method (str) – Inference method used (e.g., "Hierarchical EVI (JIT)").
wall_time_s (float) – Total wall-clock time for inference.
diagnostics (dict) – Method-specific diagnostics (e.g., number of iterations, convergence info).
Notes
This dataclass is the return type for PopulationFitter.run(). Access
population posteriors via shared_samples and shared_params, and
per-galaxy posteriors via the individual property (returns a list of
lightweight objects with .samples and .params attributes).
Examples
>>> result = fitter.run("vi", n_iterations=20)
>>> result.summary()  # Median and 68% credible intervals
>>> ax = result.plot_population(params=("sfh_field_psd_sigma", "sfh_field_psd_tau_myr"))
Computes split-Rhat and effective sample size for the shared PSD
block, and (when individual_samples is present) aggregates
per-galaxy diagnostics into population-level summaries.
Parameters:
exclude_prefixes (tuple of str, optional) – Parameter-name prefixes to skip when computing per-galaxy
diagnostics. Default ("psd_xi",) skips the GP latent
fields, which carry one chain entry per grid point and
inflate dict size without adding interpretable information.
Returns:
Two-level structure:
{
    "shared": {param_name: {"rhat": float, "ess": float}, ...},
    "per_galaxy": {  # only if individual_samples is set
        param_name: {
            "rhat_p50": float,  # median across galaxies
            "rhat_p90": float,
            "rhat_max": float,
            "ess_p50": float,
            "ess_min": float,
            "n_galaxies": int,
        },
        ...
    },
}
Use "per_galaxy" to spot a single galaxy whose chain has
stalled — a high rhat_max with low rhat_p50 is the
signature.
Reuses tengri.analysis.diagnostics.rhat() and
tengri.analysis.diagnostics.effective_sample_size(). Static
(zero-variance) parameters are dropped silently.
sample_hmc (bool) – If True, use HMC instead of ray tracing.
integrator (str) – Leapfrog integrator scheme: "dkd" (Drift-Kick-Drift, default)
or "kdk" (Kick-Drift-Kick). KDK is the time-reversed partner
of DKD. Both are symplectic and second-order. DKD matches
Behroozi’s reference implementation.
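The two splitting orders can be written out for a unit-mass Hamiltonian as below. This is a generic leapfrog sketch with hypothetical names, tested on a standard-normal target; it is not the sampler's actual integrator code.

```python
import jax
import jax.numpy as jnp


def leapfrog_step(q, p, grad_logp, step_size, scheme="dkd"):
    """One second-order symplectic step with unit mass matrix."""
    if scheme == "dkd":    # Drift-Kick-Drift: half position, full momentum, half position
        q = q + 0.5 * step_size * p
        p = p + step_size * grad_logp(q)
        q = q + 0.5 * step_size * p
    elif scheme == "kdk":  # Kick-Drift-Kick: half momentum, full position, half momentum
        p = p + 0.5 * step_size * grad_logp(q)
        q = q + step_size * p
        p = p + 0.5 * step_size * grad_logp(q)
    else:
        raise ValueError(f"unknown scheme: {scheme!r}")
    return q, p


def logp(q):
    return -0.5 * jnp.sum(q ** 2)  # standard normal target


def energy(q, p):
    return float(-logp(q) + 0.5 * jnp.sum(p ** 2))


grad_logp = jax.grad(logp)
q0 = jnp.array([1.0, -0.5])
p0 = jnp.array([0.3, 0.8])

energy_drift = {}
roundtrip_error = {}
for scheme in ("dkd", "kdk"):
    q, p = q0, p0
    for _ in range(100):
        q, p = leapfrog_step(q, p, grad_logp, 0.1, scheme)
    energy_drift[scheme] = abs(energy(q, p) - energy(q0, p0))

    # Reversibility: step forward, flip momentum, step again, flip back.
    q1, p1 = leapfrog_step(q0, p0, grad_logp, 0.1, scheme)
    q_back, p_back = leapfrog_step(q1, -p1, grad_logp, 0.1, scheme)
    roundtrip_error[scheme] = float(
        jnp.max(jnp.abs(q_back - q0)) + jnp.max(jnp.abs(-p_back - p0))
    )
```

DKD needs one gradient evaluation per step; KDK needs two per isolated step, although the trailing half-kick can be merged with the next step's leading half-kick when steps are chained.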