Inference

Classes for fitting models to data, sampling posteriors, and running hierarchical inference across galaxy populations.

Fitter

class tengri.Fitter(model, data, noise, data_type=None, data_mask=None, calibration_marginalize=False, cal_n_poly=3, cal_prior_sigma=1.0, eline_marginalize=None, eline_prior_type=None, likelihood=None, auto_protocol_likelihood=True, use_orchestrator=False, compile_modes=None)[source]

Bases: object

Inference engine for differentiable SED fitting with flexible method dispatch.

Separates inference strategy from the forward model by building a loss function from the SEDModel’s predictions and the Parameters’ priors, then running the chosen optimizer/sampler. Supports point estimation (MAP, Laplace), gradient-free and gradient-based sampling (ESS, NUTS, HMC variants, Ray Tracing), variational inference (geoVI, MGVI), and nested sampling (NSS) via a unified run(method) interface.

Parameters:
  • model (SEDModel) – Configured forward model with spec (Parameters), observation (Photometry/Spectroscopy/etc.), and predictor methods.

  • data (array_like, shape (n_data,)) – Observed data (photometric fluxes or spectra). Units match the model’s observation configuration. [erg/s/cm²/Hz] for photometry.

  • noise (array_like, shape (n_data,)) – 1-sigma measurement uncertainties. Same shape and units as data.

  • data_type (str or None) – Data type indicator: "photometry", "spectroscopy", or "joint". If None (default), inferred from model.observation. Explicit values override inference.

  • data_mask (array_like, bool or None) – Optional boolean mask for censored/non-detections. True = use datum in likelihood, False = exclude. Default None (use all).

  • calibration_marginalize (bool, optional) – If True, analytically marginalize over spectroscopic calibration polynomial coefficients (Chebyshev order 1–cal_n_poly) when computing spectroscopic log-likelihood. Only applies when data_type ∈ {"spectroscopy", "joint"}. Follows Prospector (Johnson et al. 2021). Default False.

  • cal_n_poly (int, optional) – Number of Chebyshev polynomial coefficients for calibration marginalization (order 1 through cal_n_poly). Default 3.

  • cal_prior_sigma (float, optional) – Standard deviation of Gaussian prior on each calibration coefficient. Default 1.0.

  • eline_marginalize (bool or None, optional) – Whether to analytically marginalize emission line amplitudes. None (default) auto-detects from the model’s Spectroscopy config (checks eline_mode == "marginalized").

  • eline_prior_type (str or None, optional) – Prior type for emission line marginalization: "flat" (uniform) or "cloudy" (grid-interpolated from Cloudy models). None auto-detects from Spectroscopy.eline_prior_type. Default None.

  • compile_modes (tuple[str, ...] or str or None, optional) –

    Control background JIT compilation during __init__. Accepted values:

    • None (default) → no background compile; first run() compiles lazily.

    • "auto" → inspect spec.stochastic and data_type to select sensible defaults: stochastic models → ("linear_resample", "nonlinear_update") (VI modes); otherwise → ("mcmc_nuts",).

    • explicit tuple[str, ...] (e.g., ("mcmc_nuts",)) → queue exactly those modes in the background thread.

    • explicit str (e.g., "mcmc_nuts") → wrap into a 1-tuple ("mcmc_nuts",).

    Compile modes are passed to compile(modes=...) and determine which inference engines are pre-JIT-compiled before the first run() call. See compile() docstring for valid mode names.
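The dispatch above can be sketched as a small resolver; resolve_compile_modes is an illustrative name, not part of the tengri API:

```python
def resolve_compile_modes(compile_modes, stochastic, data_type):
    """Normalize a compile_modes argument to a tuple of mode names.

    Mirrors the dispatch described above; names are illustrative only.
    """
    if compile_modes is None:
        return ()                        # lazy: first run() compiles
    if compile_modes == "auto":
        if stochastic:
            return ("linear_resample", "nonlinear_update")  # VI modes
        return ("mcmc_nuts",)
    if isinstance(compile_modes, str):
        return (compile_modes,)          # wrap a bare string into a 1-tuple
    return tuple(compile_modes)
```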

Returns:

Fitter instance with loss function compiled and ready for inference.

Return type:

Fitter

model

Reference to the input forward model.

Type:

SEDModel

data

Input data as JAX array.

Type:

ndarray, shape (n_data,)

noise

Input noise as JAX array.

Type:

ndarray, shape (n_data,)

data_type

Resolved data type ("photometry", "spectroscopy", "joint").

Type:

str

spec

Reference to model.spec.

Type:

Parameters

Notes

JIT-compatibility: Methods in this class are not JIT-compatible because they perform Python-level branching on method names and manage resources (thread compilation, caching). The returned loss function and sampler engines are fully JIT-compiled and reusable across galaxies.

Background compilation: Background compilation is now opt-in via the compile_modes parameter (default None = no background thread). The first run() call will compile lazily. Set compile_modes="auto" or compile_modes=("mcmc_nuts",) to spawn a daemon thread and pre-compile specified inference modes before run() is called (typically <1s if warm, or the full compile time on cold XLA). Set TENGRI_NO_BACKGROUND_COMPILE=1 in the environment to disable even when compile_modes is set (test environments).

Engine caching: Compiled engines are cached on the Model object so that multiple Fitters created with the same Model but different data reuse the same XLA programs. Cache key depends on data_type, dimensionality, free parameter names, and feature flags (emission lines, calibration).
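The role of data, noise, and data_mask in the likelihood can be illustrated with a minimal Gaussian log-likelihood sketch (numpy stand-in; gaussian_loglike is a hypothetical helper, not the tengri implementation):

```python
import numpy as np

def gaussian_loglike(pred, data, noise, data_mask=None):
    """Chi-square log-likelihood; masked (False) points contribute nothing."""
    resid = (data - pred) / noise
    chi2 = resid**2 + np.log(2.0 * np.pi * noise**2)  # per-datum terms
    if data_mask is not None:
        chi2 = np.where(data_mask, chi2, 0.0)          # drop censored data
    return -0.5 * np.sum(chi2)
```

Excluding a non-detection this way is equivalent to fitting only the detected bands.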


Examples

Fit a single galaxy with geoVI (default):

>>> import jax.numpy as jnp
>>> from tengri import SEDModel, Fitter, Parameters
>>> model = SEDModel(Parameters())
>>> data = jnp.array([1.2, 0.8, 0.5])  # photometric fluxes
>>> noise = jnp.array([0.1, 0.08, 0.06])
>>> fitter = Fitter(model, data, noise)
>>> result = fitter.run("vi", n_samples=100)
>>> print(result.params)

Fit with warm-start from MAP:

>>> result_map = fitter.run("map", n_steps=1000)
>>> result_mcmc = fitter.run("mcmc_nuts", init_from=result_map, n_warmup=500)

See the docstring of run() for all available methods and their options.

compile(*, n_iterations=15, n_samples=3, n_posterior_samples=200, modes=('linear_resample', 'nonlinear_update'), mcmc_methods=(), n_warmup=300, n_burnin=100, n_mcmc_samples=100, nss=False, verbose=True)[source]

Pre-compile the JIT inference engine ahead of time.

Triggers XLA compilation for all specified modes so that subsequent fitter.run() calls have zero compilation delay. Compiled programs are cached both in-memory (this session) and on disk (/tmp/tengri_jax_cache, survives restarts).

Parameters:
  • n_iterations (int) – Iteration count for the pre-compilation run. Changing n_iterations at run time does NOT trigger recompilation (the iteration count is a dynamic traced value).

  • n_samples (int) – Compile for this sample count. Changing n_samples at run time DOES trigger recompilation (array shapes depend on it).

  • n_posterior_samples (int) – Compile posterior draw for this many samples.

  • modes (tuple of str) – Which VI sample modes to pre-compile. Each mode compiles separately. Default covers MGVI + geoVI update (fastest). Add "nonlinear_resample" for full geoVI (~56s extra).

  • mcmc_methods (tuple of str) – MCMC methods to pre-compile. Supported values: "nuts", "hmc", "dynamic_hmc", "ghmc". Each call runs the full warmup + chain scan through JIT so the XLA disk cache is populated before the first user call. After fitter.compile(mcmc_methods=["nuts"]), a fresh kernel restart deserializes in <1s instead of ~23s.

  • n_warmup (int) – Warmup steps used for the MCMC compilation run.

  • n_burnin (int) – Burn-in steps used for the MCMC compilation run.

  • n_mcmc_samples (int) – Sample steps used for the MCMC compilation run.

  • nss (bool) – Pre-compile the NSS (nested slice sampling) step and init functions. NSS has a ~10–15s cold compile on the first fitter.run("nss") call; setting nss=True moves that cost to compile time. data_args is traced so the compiled program is reused across galaxies with the same model configuration.

  • verbose (bool) – Print compilation progress.

Return type:

self

Notes

Compilation mechanics: Pre-compilation invokes jax.jit on the forward model’s SED prediction and inference engines, storing compiled XLA programs to disk. First fitter.run() will skip XLA overhead by loading pre-compiled kernels. Typical times: "linear_resample" + "nonlinear_update" ~3s; full modes ~60s; NUTS ~23s (once per unique model shape).

MCMC cache key: The XLA program is keyed on logdensity_fn_2arg identity, n_warmup, n_burnin, n_mcmc_samples, and use_dense. Use the same values here as in fitter.run() to guarantee a cache hit. Changing galaxy data does not invalidate the cache (data_args is traced, not static).

JIT-compatible: yes — internally calls JIT-compiled JAX functions.

Example

>>> fitter = Fitter(model, data, noise)
>>> fitter.compile()  # ~3s for default VI modes
>>> fitter.compile(mcmc_methods=["nuts"])  # ~23s, then instant restarts
>>> fitter.compile(nss=True)  # ~12s, then instant restarts
>>> result = fitter.run("mcmc_nuts")  # instant after compile
>>> result = fitter.run("nss")  # instant after compile
compile_signature() → tuple[source]

Return a hashable signature for cross-galaxy engine reuse.

Combines SEDModel’s compile_signature() with Fitter-specific parameters that affect the compiled inference engine. Two Fitters with matching signatures can share the same XLA-compiled engine, even if they reference different SEDModel instances (as long as those instances have the same compile_signature).

The signature does NOT include memory_mode, as it does not change the generated HLO graph — it only affects posterior-chunking behavior in the analysis layer (see _draw_jit_samples and _draw_nonlinear_jit_samples). Toggling memory_mode between “fast” and “low” reuses the same cached engine.

Returns:

Hashable immutable signature suitable for keying into the module-level _SHARED_ENGINE_CACHE.

Return type:

tuple

Notes

Used by _get_or_build_engine to enable cross-galaxy engine reuse in PopulationFitter and CatalogFitter. The signature is computed ONCE per Fitter construction and cached to avoid recomputation in tight loops.
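A minimal sketch of this caching pattern, assuming a simplified signature (field names illustrative; tengri’s actual signature includes more flags):

```python
def engine_signature(data_type, n_data, free_params, use_elines, use_cal):
    """Hashable signature for cross-galaxy engine reuse; follows the
    cache-key description above (field names illustrative)."""
    return (data_type, n_data, tuple(sorted(free_params)), use_elines, use_cal)

_SHARED_ENGINE_CACHE = {}

def get_or_build_engine(sig, build):
    """Return the cached engine for sig, building it on first request."""
    if sig not in _SHARED_ENGINE_CACHE:
        _SHARED_ENGINE_CACHE[sig] = build()
    return _SHARED_ENGINE_CACHE[sig]
```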

fit_batch(batch, *, method='vi', key=None, verbose=True, **kwargs)[source]

Fit a batch of galaxies efficiently.

Creates a Fitter per galaxy, sharing the XLA compilation cache. The first galaxy pays compile cost; subsequent galaxies load from the persistent XLA cache (milliseconds each).

Works with any inference method — vi (default) gives the best speed. Also usable for hierarchical individual fits.

Parameters:
  • batch (list of dict) – Each dict has "flux_obs" and "noise" arrays.

  • method (str) – Default "vi". Any method from run().

  • key (PRNGKey, optional) – Random seed for sampling methods. Default: jax.random.PRNGKey(42).

  • verbose (bool) – Print progress. Default: True.

  • **kwargs – Passed to run() (n_iterations, n_samples, n_seeds, etc).

Returns:

Inference results for each galaxy, in order.

Return type:

list of Posterior

Notes

Parallelization strategy:

  • For method="map" with precomputed photometry: uses jax.vmap to fit all galaxies in a single JIT call (1–2 s total).

  • For MCMC methods with fixed SFH: uses jax.vmap + shared adaptation.

  • Otherwise: sequential Fitter per galaxy (loads from the XLA cache).

Compilation caching: All Fitters share the same Model instance, enabling persistent XLA cache. After first galaxy, subsequent fits are 10-100× faster depending on method.

Native VI tuning: When method contains "native" and n_seeds is not explicitly passed, automatically sets n_seeds=5 for better convergence.
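The vmap strategy — fitting every galaxy in one vectorized call — can be sketched with a batched closed-form weighted least-squares amplitude fit (numpy stand-in, not tengri’s actual MAP path):

```python
import numpy as np

def fit_batch_amplitude(flux_obs, noise, template):
    """Weighted least-squares amplitude per galaxy, fully vectorized.

    flux_obs, noise: shape (n_gal, n_band); template: shape (n_band,).
    Equivalent in spirit to vmapping a single-galaxy MAP fit.
    """
    w = 1.0 / noise**2
    num = np.sum(w * flux_obs * template, axis=1)
    den = np.sum(w * template**2, axis=1)
    return num / den
```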

Examples

Batch fit 100 galaxies:

>>> batch = [{"flux_obs": f, "noise": n} for f, n in zip(fluxes, noises)]
>>> results = fitter.fit_batch(batch, method="vi")
>>> # First galaxy: ~2s compile. Subsequent: ~2ms each from cache.

Warm-start from MAP:

>>> results_map = fitter.fit_batch(batch, method="map", n_steps=500)
>>> results_vi = fitter.fit_batch(batch, method="vi", init_from=results_map)
run(method: str = 'vi', *, init_from=None, key=None, **kwargs)[source]

Run inference using the specified method.

Dispatches to the underlying inference backend (variational, MCMC, point estimation, or nested sampling) and returns a Posterior object with samples, diagnostics, and derived quantities.

Parameters:
  • method (str, optional) –

    Inference method (case-sensitive). Default "vi".

    Variational Inference (VI)

    • "vi" — geoVI via NIFTy (nonlinear, default for D>20)

    • "vi_linear" — MGVI via NIFTy (linearized Gaussian)

    • "vi_nifty_fast" — geoVI fast path (~35% faster, no logging)

    • "vi_nifty_fast_linear" — MGVI fast path (~35% faster, no logging)

    • "vi_native" — Native JAX geoVI (experimental; ~19× faster than NIFTy)

    • "vi_native_linear" — Native JAX MGVI (experimental)

    MCMC Sampling

    • "mcmc_nuts" — NUTS via BlackJAX (default for D≤20; exact posterior)

    • "mcmc_raytrace" — Ray Tracing (Behroozi 2025; O(1) gradient cost)

    • "mcmc" — Auto: NUTS (D≤20) or Ray Tracing (D>20)

    • "mcmc_hmc" — Standard HMC (fixed trajectory length)

    • "mcmc_dynamic_hmc" — Dynamic HMC (adaptive trajectory)

    • "mcmc_ghmc" — Generalized HMC (partial momentum refresh)

    • "mcmc_mclmc" — MCLMC (O(1) grad/sample, biased)

    • "mcmc_adjusted_mclmc" — MCLMC + Metropolis correction

    • "mcmc_ess" — Elliptical Slice Sampling (gradient-free)

    Point Estimation & Approximations

    • "map" — MAP optimization (Adam by default)

    • "laplace" — Laplace approximation (Gaussian posterior at MAP)

    • "pathfinder" — L-BFGS trajectory + sequence of Gaussians (Zhang+2022)

    Model Comparison (Bayesian Evidence)

    • "nss" — Nested Slice Sampling (exact Z, D≤30)

    Automatic Selection

    • "auto" — NUTS (D≤20) or geoVI (D>20) based on dimensionality

    Deprecated Aliases (still work, emit DeprecationWarning):

    • "vi_nifty", "geovi", "fast_geovi", "nifty_geovi" → "vi"

    • "vi_nifty_linear", "mgvi", "fast_mgvi", "nifty_mgvi" → "vi_linear"

    • "native_geovi" → "vi_native"

    • "native_mgvi", "native_evi" → "vi_native_linear"

    • "raytrace", "nuts", "hmc", etc. → the corresponding "mcmc_*" methods

    • "evidence" → "nss"

  • init_from (Posterior, optional) – Previous inference result to use as warm-start initialization. The posterior mean is extracted and converted to unbounded space. Useful for refining results across different methods. Default None.

  • key (PRNGKey, optional) – JAX random key. Default PRNGKey(42) for reproducibility. Ignored for deterministic methods ("map", "laplace").

  • **kwargs

    Method-specific keyword arguments passed to the underlying backend:

    • VI methods: n_samples, n_kl_iter, tol_kl, sample_mode, verbose, mirror_samples.

    • MCMC methods: n_steps, n_warmup, thin, step_size, mass_matrix, adapt_step_size, verbose.

    • MAP/Laplace: n_steps, step_size, lr, verbose.

    • Pathfinder: n_steps, n_init, step_size, verbose.

    • NSS: n_live, n_batch, slice_width, verbose.

    See backend docstrings for full option documentation.

Returns:

Inference results object with attributes:

  • samples : dict or None — Posterior samples (None for MAP).

  • params : dict — Best-fit or posterior mean parameters.

  • method : str — Method used.

  • diagnostics : dict — Convergence/quality metrics.

  • log_evidence : float or None — Bayesian evidence (NSS only).

  • wall_time_s : float — Total runtime.

The Posterior also has derived quantity methods: derived, summary(), to_arviz(), refine(), etc.

Return type:

Posterior

Raises:
  • ParameterError – If method is invalid or unrecognized.

  • RuntimeError – If background JIT compilation failed.

  • ValueError – If method-specific kwargs are invalid.

Notes

Method selection strategy:

  • Default ("vi"): geoVI is recommended for high-dimensional problems (D>50) and population fitting. Captures non-Gaussian posterior geometry.

  • Exact posterior ("mcmc_nuts"): Use for D≤20 where exact sampling is feasible and posterior validation is critical.

  • Fast large-D sampling ("mcmc_raytrace"): Use for D>50 with gradient access; 250× more robust to noisy gradients than HMC.

  • Bayesian model comparison ("nss"): Estimates log-evidence for comparing competing physical models (e.g., different dust laws).

Important gotchas:

  • VI posterior equivalence: "vi" (NIFTy geoVI) and "vi_native" (pure JAX geoVI) target the same objective but are NOT posterior-equivalent. The native version is ~19× faster but produces different posterior shapes on some problems (e.g., PSD timescale can differ by order of magnitude). Validate before swapping methods. See bench/reports/2026-04-17_native_vs_nifty.md.

  • VIConfig.n_samples doubling: In geoVI, when mirror_samples=True (default), n_samples=3 produces 6 effective samples (3 + 3 mirrors). When tuning convergence, think in effective samples.

  • Ray Tracing step_size scaling: Ray Tracing uses step_size=0.05 by default for D~137. There is a sharp viability cliff at ~0.06 where acceptance drops from 80% to 0%. Use smaller step sizes for safety.

  • Method defaults from file: Default hyperparameters (n_kl_iter, n_warmup, etc.) are loaded from defaults.toml if available. Command-line kwargs override file defaults.

Warm-starting from MAP:

Fitting often proceeds in stages:

>>> result_map = fitter.run("map", n_steps=1500)
>>> result_mcmc = fitter.run("mcmc_nuts", init_from=result_map, n_warmup=500)
>>> result_vi = fitter.run("vi", init_from=result_map, n_samples=100)

MAP provides a quick point estimate; MCMC and VI refine from this initialization, converging faster than from random initialization.

Reproducibility:

Pass key=jax.random.PRNGKey(seed) to control randomness across runs. key=None defaults to PRNGKey(42) for reproducibility.

Compile-cache behaviour (smart lean, 2026-05):

run accepts a lean kwarg (default inferred from tengri.lean() / tengri.persistent() context). With lean=True (the default), the inference-body cache is cleared of stale entries — every entry whose (compile_signature, method) differs from the current call is dropped, but the entry that matches the current call (if it exists from a prior identical run) is kept. Forward-model, log-density, loss, and gradient compiles survive unconditionally. Implications:

  • Multi-phase notebooks (MAP → HMC → posterior-predictive) peak at one inference scan body in RAM, not several.

  • Catalog loops calling fitter.run(method) repeatedly with the same model and method pay one compile, not N — without needing tengri.persistent().

tengri.gc() drops everything including structural caches; use it between loops that build many different model configurations.
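The lean eviction policy can be sketched as a one-key-survives dictionary sweep (lean_evict and the cache layout are illustrative, not tengri internals):

```python
def lean_evict(inference_cache, current_key):
    """Drop stale inference-body entries, keeping the current one.

    inference_cache maps (compile_signature, method) -> compiled body.
    Structural caches (forward model, loss, gradients) live elsewhere
    and are untouched by this sweep.
    """
    for key in list(inference_cache):
        if key != current_key:
            del inference_cache[key]
```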


Examples

Example 1: Quick exploration with MAP + geoVI

>>> fitter = Fitter(model, data, noise)
>>> result = fitter.run("vi")  # geoVI with defaults
>>> print(result.summary())

Example 2: Exact posterior with NUTS (small-D)

>>> result = fitter.run("mcmc_nuts", n_warmup=500, n_steps=2000)
>>> samples = result.samples["stellar_mass"]
>>> print(f"M_star = {jnp.median(samples):.2e} Msun")

Example 3: Warm-start MCMC from MAP

>>> result_map = fitter.run("map", n_steps=1500)
>>> result_mcmc = fitter.run("mcmc_nuts", init_from=result_map, n_warmup=300, n_steps=1000)

Example 4: Nested sampling for Bayesian model comparison

>>> result_nss = fitter.run("nss", n_live=100)
>>> log_z = result_nss.log_evidence
>>> print(f"log(Z) = {log_z:.2f}")  # Use for Bayes factors

Example 5: Using "auto" method for unknown dimensionality

>>> result = fitter.run("auto")  # NUTS if D≤20, VI if D>20
summary() → str[source]

Return a human-readable summary of the fitting problem.

Returns:

Formatted summary showing data shape, free parameters, priors, bounds, and available inference methods.

Return type:

str

Notes

The summary includes:

  • Data dimensionality and median signal-to-noise ratio

  • Free parameters and latent grid points (ξ) if stochastic SFH

  • Parameter names, prior distributions, and bounds

  • All available inference methods

Examples

>>> fitter = Fitter(model, data, noise)
>>> print(fitter.summary())
Fitter  data_type: photometry
──────────────────────────────────────────────────────────────
  Data points: 100
  Median S/N:  5.2
  Parameters:  8 free + 64 latent (ξ)
...

Posterior

class tengri.Posterior(samples: dict | None, params: dict, method: str, wall_time_s: float, diagnostics: dict, loss_history: Array | None = None, log_evidence: float | None = None, _model: object = None, _fitter: object = None, eline_fluxes: Array | None = None, eline_flux_cov: Array | None = None, eline_names: tuple | None = None, eline_wavelengths: Array | None = None)[source]

Bases: object

Inference results with samples, diagnostics, and derived quantities.

Stores posterior samples (or point estimate for MAP), best-fit parameters, convergence diagnostics, and provides methods for summary statistics, derived physical quantities, ArviZ conversion, and refinement via resampling or additional fitting iterations.

Parameters:
  • samples (dict or None) – Posterior samples in physical parameter space (optional, set by inference).

  • params (dict) – Best-fit or posterior mean parameters.

  • method (str) – Inference method name (e.g., "vi", "mcmc_nuts", "map").

  • wall_time_s (float) – Total wall-clock runtime in seconds.

  • diagnostics (dict) – Method-specific convergence metrics.

  • loss_history (ndarray or None) – Optimization loss values (optimization methods only).

  • log_evidence (float or None) – Log Bayesian evidence (NSS only).

  • _model (SEDModel, optional) – Forward model reference.

  • _fitter (Fitter, optional) – Fitter reference for refinement methods.

  • eline_fluxes (ndarray or None) – Emission line fluxes [erg/s/cm²].

  • eline_flux_cov (ndarray or None) – Emission line flux covariance.

  • eline_names (tuple or None) – Emission line identifiers.

  • eline_wavelengths (ndarray or None) – Rest-frame vacuum wavelengths [Angstrom].

Returns:

Posterior instance with results populated.

Return type:

Posterior

samples

Posterior samples in physical parameter space. Each value has shape (n_samples, …). Keys are parameter names (e.g., "stellar_mass", "age_gyr", "psd_xi"). None for point estimates (MAP, Laplace, Pathfinder).

Type:

dict or None

params

Best-fit (MAP for point estimation) or posterior mean parameters in physical space. Same keys as samples (without "psd_xi" latent field).

Type:

dict

method

Inference method name (e.g., "vi", "mcmc_nuts", "map").

Type:

str

wall_time_s

Total wall-clock runtime in seconds, including compilation and sampling.

Type:

float

diagnostics

Method-specific convergence and quality metrics. Contents vary by method:

  • VI methods: {"kl_iter": int, "kl_final": float}, etc.

  • NUTS: {"n_divergent": int, "accept_rate": float}, etc.

  • Ray Tracing: {"accept_rate": float, "step_size": float}, etc.

  • MAP: {"final_loss": float, "n_steps": int}, etc.

  • NSS: {"n_live": int, "log_evidence_err": float}, etc.

Type:

dict

loss_history

Optimization loss values over iterations (MAP/Laplace/Pathfinder only). Shape (n_iterations,). None for sampling methods.

Type:

ndarray or None

log_evidence

Bayesian evidence log(Z) integral (NSS only). None for other methods. Used for model comparison via Bayes factors.

Type:

float or None

_model

Reference to the forward model. Required for computing derived quantities (stellar mass, SFR, sSFR, etc.). Set by Fitter.run() automatically.

Type:

SEDModel, optional

_fitter

Reference to the Fitter instance. Enables refine() and other refinement methods. Set by Fitter.run() automatically.

Type:

Fitter, optional

eline_fluxes

Emission line fluxes. Shape (n_lines,) for MAP, (n_samples, n_lines) for sampling. None if no emission line fitting/marginalization was enabled. Flux units match input data [erg/s/cm²].

Type:

ndarray or None

eline_flux_cov

Posterior covariance of emission line fluxes. Shape (n_lines, n_lines) for MAP, (n_samples, n_lines, n_lines) for sampling. None if unavailable.

Type:

ndarray or None

eline_names

Emission line identifiers (e.g., ("Halpha", "Hbeta", ...)) matching eline_fluxes column order.

Type:

tuple or None

eline_wavelengths

Rest-frame vacuum wavelengths [Angstrom] of emission lines, matching eline_fluxes column order.

Type:

ndarray or None

Notes

Derived quantities: The derived property computes stellar mass, SFR, sSFR, etc. by re-running the forward model on all samples. For MAP results, returns scalars; for MCMC/VI results, returns arrays (one per sample).

Emission line diagnostics: Methods line_fluxes(), bpt_nii(), and balmer_decrement() provide astrophysical diagnostics on emission lines. Require eline_mode != "none" in Spectroscopy config.

Convergence diagnostics: Use check_convergence(), autocorrelation(), and effective_sample_size() to assess MCMC chain quality.

Resampling and refinement: Use resample() to draw new samples from the posterior, and refine() to improve results by running additional inference iterations (requires _fitter).

See also

Fitter.run

Returns Posterior with all attributes populated.

Fitter

Primary interface for inference.

Examples

Basic usage:

>>> result = fitter.run("mcmc_nuts")  # Returns Posterior
>>> print(result.summary_table())
>>> params_phys = result.params
>>> samples = result.samples

Derived quantities:

>>> derived = result.derived
>>> stellar_masses = derived["stellar_mass"]  # Shape (n_samples,)
>>> med, lo, hi = np.percentile(stellar_masses, [50, 16, 84])

Emission line diagnostics:

>>> fluxes = result.line_fluxes()
>>> ha_med, ha_lo, ha_hi = fluxes["Halpha"]
>>> x, y = result.bpt_nii()  # BPT diagram coordinates
>>> plt.scatter(x, y, alpha=0.3)

Convergence checks:

>>> converged = result.check_convergence()
>>> ess = result.effective_sample_size()
>>> print(f"Effective sample size: {ess['stellar_mass']:.0f}")

Refinement via resampling:

>>> refined_samples = result.resample(key, n=50)  # Resample with replacement
>>> refined = fitter.run("vi", init_from=result)  # Refine posterior
agn_fraction(wavelength=None) → Array[source]

Wavelength-resolved AGN luminosity fraction.

Returns the median over posterior draws of L_agn(λ) / L_total(λ) (or the MAP point estimate as a single curve).

Parameters:

wavelength (array_like, optional) – Rest-frame wavelength grid. Defaults to the model’s grid.

Returns:

Median AGN fraction at each wavelength.

Return type:

ndarray, shape (n_wave,)

Notes

Wraps sed_components(). To get full credible intervals on the fraction, call sed_components directly and compute percentiles on sed_agn / sed_total.

autocorrelation(max_lag: int | None = None) → dict[source]

Compute autocorrelation function for each scalar parameter.

Parameters:

max_lag (int, optional) – Maximum lag. Default: n_samples // 2.

Returns:

Keys: parameter names. Values: ndarray of autocorrelation from lag 0 to max_lag. ACF[0] = 1.0 by definition.

Return type:

dict

Notes

Uses FFT for O(N log N) efficiency instead of naive O(N * max_lag). Normalized so that ACF[0] = 1 and ACF[k] ∈ [-1, 1]. Parameters with zero variance (fixed parameters) return zero ACF.
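A minimal FFT-based ACF consistent with the normalization described above (acf_fft is an illustrative stand-alone implementation, not the tengri method):

```python
import numpy as np

def acf_fft(x, max_lag=None):
    """Autocorrelation function via FFT, normalized so acf[0] == 1."""
    x = np.asarray(x, dtype=float)
    n = x.size
    if max_lag is None:
        max_lag = n // 2
    xc = x - x.mean()
    var = np.dot(xc, xc)
    if var == 0.0:                        # fixed parameter: zero ACF
        return np.zeros(max_lag + 1)
    # zero-pad to 2n so the circular correlation equals the linear one
    f = np.fft.rfft(xc, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f))[: max_lag + 1]
    return acov / var
```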

Examples

>>> acf = result.autocorrelation()
>>> lag_cutoff = np.argmax(acf["stellar_mass"] < 0.05)
>>> print(f"ACF drops below 0.05 at lag {lag_cutoff}")
autocorrelation_time() → dict[source]

Estimate integrated autocorrelation time for each parameter.

Uses Sokal’s self-consistent window method with both standard and absolute-deviation modes (Behroozi 2025).

Returns:

Keys: parameter names. Values: dict with 'tau_standard', 'tau_absolute', 'tau_max' (integrated autocorrelation time), 'ess' (effective sample size), 'chain_converged' (bool, True if N > 5τ_max).

Return type:

dict

Notes

Two autocorrelation time estimates are computed:

  • tau_standard: based on the standard ACF.

  • tau_absolute: based on the absolute-deviation ACF (robust to mean/variance changes).

The maximum of the two is returned (conservative). ESS = N / tau_max. The convergence flag uses the criterion N > 5τ_max from Behroozi (2025).
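Sokal’s self-consistent window (the standard-ACF branch only) can be sketched as follows; integrated_time is illustrative, not the tengri implementation:

```python
import numpy as np

def integrated_time(x, c=5.0):
    """Integrated autocorrelation time via Sokal's self-consistent window:
    take the smallest window M with M >= c * tau(M)."""
    x = np.asarray(x, dtype=float)
    n = x.size
    xc = x - x.mean()
    f = np.fft.rfft(xc, n=2 * n)
    acov = np.fft.irfft(f * np.conj(f))[:n]
    acf = acov / acov[0]
    tau = 2.0 * np.cumsum(acf) - 1.0   # tau(M) = 1 + 2 * sum_{k=1..M} acf[k]
    window = np.arange(n)
    m = np.argmax(window >= c * tau)   # first self-consistent window
    tau_hat = tau[m] if m > 0 else tau[-1]
    return tau_hat, n / tau_hat        # (tau, effective sample size)
```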

Examples

>>> tau_dict = result.autocorrelation_time()
>>> for param, info in tau_dict.items():
...     converged = info["chain_converged"]
...     print(
...         f"{param}: tau_max={info['tau_max']:.1f}, "
...         f"ESS={info['ess']:.0f}, converged={converged}"
...     )
balmer_av() → tuple[float, float, float][source]

Visual extinction A(V) from the Balmer decrement (Calzetti+2000).

Converts the observed Hα/Hβ ratio to a V-band attenuation assuming Case B recombination (intrinsic ratio 2.86) and the Calzetti et al. 2000 [1]_ starburst attenuation law (\(R_V = 4.05\), \(k(H\alpha) = 2.53\), \(k(H\beta) = 3.61\)).

Returns:

(median, lo_68, hi_68) of A(V) in [mag]. For MAP results all three values are equal. Negative values are returned as-is (unphysical, but informative for noisy data — clip externally if desired).

Return type:

tuple

Raises:

ValueError – If Hα or Hβ fluxes are not available.

Notes

\[\begin{split}E(B-V) &= \frac{\log_{10}\!\left(R_{\rm obs}/2.86\right)} {0.4\,\left(k(H\beta) - k(H\alpha)\right)} \\ A(V) &= R_V \cdot E(B-V)\end{split}\]

For Calzetti+2000: \(0.4 \cdot (k(H\beta) - k(H\alpha)) = 0.432\), \(R_V = 4.05\), so \(A(V) \approx 9.375 \, \log_{10}(R_{\rm obs}/2.86)\).
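The closed-form conversion can be checked numerically (illustrative helper using the Calzetti+2000 coefficients above):

```python
import numpy as np

def balmer_av(ratio_obs, r_v=4.05, k_hb=3.61, k_ha=2.53, intrinsic=2.86):
    """A(V) from the observed Halpha/Hbeta ratio, Calzetti+2000 law.

    E(B-V) = log10(R_obs / 2.86) / (0.4 * (k_Hb - k_Ha)); A(V) = R_V * E(B-V).
    Negative values (R_obs < 2.86) are returned as-is, matching the docs.
    """
    ebv = np.log10(ratio_obs / intrinsic) / (0.4 * (k_hb - k_ha))
    return r_v * ebv
```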

References

Examples

>>> av_med, av_lo, av_hi = result.balmer_av()
balmer_decrement() → tuple[float, float, float][source]

Observed Hα/Hβ ratio from posterior line fluxes.

Returns the posterior distribution of the Balmer decrement (Hα/Hβ), which is a direct dust attenuation diagnostic. The intrinsic Case B ratio is 2.86; higher values indicate dust.

Returns:

(median, lo_68, hi_68) of Hα/Hβ (dimensionless ratio). For MAP results, all three values are equal.

Return type:

tuple

Raises:

ValueError – If Hα or Hβ fluxes are not available.

Notes

The Balmer decrement (Hα/Hβ flux ratio) is insensitive to stellar population age and metallicity; deviations from the intrinsic Case B value of 2.86 (Osterbrock 1989) directly indicate dust attenuation. A Balmer decrement > 3.0 typically indicates significant dust.

Examples

>>> med, lo, hi = result.balmer_decrement()
>>> print(f"Ha/Hb = {med:.2f} [{lo:.2f}, {hi:.2f}]")
>>> # Intrinsic Case B = 2.86; excess indicates dust attenuation
bpt_class()[source]

Classify each posterior draw as SF / composite / AGN on the BPT-NII diagram.

Uses the standard demarcation lines:

  • Kauffmann et al. 2003 [1]_ — separates pure SF from composite (SF + AGN admixture).

  • Kewley et al. 2001 [2]_ — separates composite from pure AGN/Seyfert.

Below Kauffmann ⇒ "SF"; between Kauffmann and Kewley ⇒ "composite"; above Kewley ⇒ "AGN". Non-detections (NaN ratios) return "unknown".

Returns:

For MAP results: a single label string. For sampling results: a length-n_samples array of labels.

Return type:

str or ndarray of dtype <U9

Raises:

ValueError – If emission line fluxes are unavailable or BPT lines absent.

Notes

Kauffmann+2003 demarcation:

\[\log_{10}\!\frac{[\mathrm{O\,III}]}{H\beta} = \frac{0.61}{\log_{10}([\mathrm{N\,II}]/H\alpha) - 0.05} + 1.30 \quad (\text{for }\log [\mathrm{N\,II}]/H\alpha < 0.05)\]

Kewley+2001 demarcation:

\[\log_{10}\!\frac{[\mathrm{O\,III}]}{H\beta} = \frac{0.61}{\log_{10}([\mathrm{N\,II}]/H\alpha) - 0.47} + 1.19 \quad (\text{for }\log [\mathrm{N\,II}]/H\alpha < 0.47)\]

Points right of the asymptote (\(\log [\mathrm{N\,II}]/H\alpha \ge 0.47\)) are classified as AGN regardless of [O III]/Hβ.
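The classification logic can be sketched as a scalar helper implementing the two demarcation curves above (illustrative; tengri’s method operates on posterior draws):

```python
import numpy as np

def bpt_class(log_nii_ha, log_oiii_hb):
    """SF / composite / AGN label from the Kauffmann+03 and Kewley+01 curves."""
    x, y = log_nii_ha, log_oiii_hb
    if np.isnan(x) or np.isnan(y):
        return "unknown"
    # Curves are only defined left of their asymptotes (0.05 and 0.47)
    kauffmann = 0.61 / (x - 0.05) + 1.30 if x < 0.05 else -np.inf
    kewley = 0.61 / (x - 0.47) + 1.19 if x < 0.47 else -np.inf
    if x < 0.05 and y < kauffmann:
        return "SF"
    if x < 0.47 and y < kewley:
        return "composite"
    return "AGN"   # above Kewley, or right of the asymptote
```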

References

Examples

>>> labels = result.bpt_class()
>>> import numpy as np
>>> agn_frac = float(np.mean(np.asarray(labels) == "AGN"))
bpt_nii() → tuple[Array, Array][source]

BPT-NII ([NII]/Hα vs [OIII]/Hβ) diagram coordinates.

Returns log10 line ratios for each posterior sample.

Returns:

  • log_nii_ha (ndarray, shape (n_samples,) or scalar) – log10([NII]6584 / Hα). For MAP, returns scalar.

  • log_oiii_hb (ndarray, shape (n_samples,) or scalar) – log10([OIII]5007 / Hβ). For MAP, returns scalar.

Raises:

ValueError – If emission line fluxes are not available or BPT lines are missing.

Notes

The BPT diagram is a standard AGN/SF diagnostic that uses the ratios [NII]/Hα (x-axis) and [OIII]/Hβ (y-axis). Non-detections (negative or zero fluxes) are returned as NaN and will not be plotted. Diagnostic lines follow Kewley et al. (2001, 2006) conventions.

Examples

>>> import matplotlib.pyplot as plt
>>> x, y = result.bpt_nii()
>>> plt.scatter(x, y, alpha=0.3)
>>> # Overlay diagnostic lines from starburst/Seyfert boundaries
check_convergence(verbose: bool = True) dict[source]

Check chain convergence using autocorrelation diagnostics.

Follows Behroozi (2025): chain is converged when N > 5τ for all parameters.

Parameters:

verbose (bool) – Print diagnostics table.

Returns:

Keys: 'all_converged' (bool), 'params' (dict per-parameter convergence info), 'warnings' (list of unconverged parameters).

Return type:

dict

Notes

Convergence criterion: N > 5τ_max for all parameters. If N < 5τ for any parameter, the chain is flagged as unconverged. See Behroozi (2025) for justification of the 5τ threshold.

Examples

>>> conv = result.check_convergence()
>>> if conv["all_converged"]:
...     print("Chain converged!")
... else:
...     print(f"Unconverged: {conv['warnings']}")
...     print("Run additional samples and use refine()")
property derived: dict

Derived physical quantities (stellar mass, SFR, sSFR).

For MAP: computed on the single best-fit → dict of scalars. For NUTS/geoVI: computed on all samples → dict of arrays.

Returns:

Keys: derived quantity names ("stellar_mass", "sfr_100myr", etc.). For MAP: values are scalars. For sampling: values are arrays of shape (n_samples,). Units match the forward model convention (stellar mass in [Msun], SFR in [Msun/yr]).

Return type:

dict

Notes

This is a cached property — computed on first access and cached thereafter. Requires _model to be set (populated automatically by Fitter.run()). For stochastic SFH, unresolved bursts are included in sfr_10myr and sfr_100myr outputs.

Examples

>>> derived = result.derived
>>> stellar_masses = derived["stellar_mass"]  # Shape (n_samples,)
>>> med, lo, hi = np.percentile(stellar_masses, [50, 16, 84])
diagnostics: dict
diagnostics_summary() str[source]

Print a diagnostics summary with ESS and credible intervals.

Returns:

Formatted table of per-parameter ESS and 68% credible intervals.

Return type:

str

Notes

For MAP results, returns a simple string indicating no samples are available. For sampling methods, tabulates ESS for each parameter alongside median and 68% credible intervals. Use rhat() for the split-chain Gelman-Rubin diagnostic.

Examples

>>> print(result.diagnostics_summary())
Method: mcmc_nuts
Samples: 1000
Wall time: 5.2s

Parameter        Median      68% CI               ESS
─────────────────────────────────────────────────────
stellar_mass     10.500      [10.300, 10.700]     800
…

effective_sample_size() dict[source]

Estimate effective sample size (ESS) for each parameter.

Uses Sokal’s self-consistent window method (Behroozi 2025): τ = 1 + 2 Σ ρ(k), truncated at k > 5τ. Takes the max of standard and absolute-deviation autocorrelation times for a conservative estimate.

Returns:

Keys: parameter names. Values: ESS (float). ESS = N / τ, where N is the total number of samples and τ is the integrated autocorrelation time.

Return type:

dict

Notes

ESS measures the number of independent samples. Low ESS (< 100 for typical analyses) indicates poor mixing. The threshold N > 5τ (equivalently ESS > N/5) indicates adequate sampling for most purposes.
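A minimal NumPy sketch of the self-consistent window described above (illustrative only; the library's estimator additionally takes the max with the absolute-deviation autocorrelation time):

```python
import numpy as np

def integrated_autocorr_time(x, c=5.0):
    """Self-consistent window estimate of tau: stop summing once M >= c * tau(M)."""
    x = np.asarray(x, dtype=float)
    n = x.size
    x = x - x.mean()
    # Linear autocorrelation via zero-padded FFT
    f = np.fft.rfft(x, n=2 * n)
    acf = np.fft.irfft(f * np.conj(f))[:n]
    acf = acf / acf[0]
    # Running tau(M) = 1 + 2 * sum_{k=1}^{M} rho(k)
    tau_run = 1.0 + 2.0 * np.cumsum(acf[1:])
    for m in range(1, n):
        if m >= c * tau_run[m - 1]:
            return float(tau_run[m - 1])
    return float(tau_run[-1])

# ESS = N / tau
```

For white noise this returns tau ≈ 1 (every sample independent); for a strongly correlated chain tau grows and ESS = N / tau shrinks accordingly.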

Examples

>>> ess = result.effective_sample_size()
>>> print(f"Stellar mass ESS: {ess['stellar_mass']:.0f}")
>>> if ess["stellar_mass"] < 100:
...     print("Warning: low ESS, may need more samples")
eline_flux_cov: Array | None = None
eline_fluxes: Array | None = None
eline_names: tuple | None = None
eline_wavelengths: Array | None = None
equivalent_widths(window_aa: float = 20.0, continuum_width_aa: float = 50.0) dict[str, tuple[float, float, float]][source]

Rest-frame emission line equivalent widths from posterior samples.

For each emission line, predicts the rest-frame SED for each posterior sample (or the MAP point estimate) using the attached forward model and integrates the line flux relative to the local continuum estimated from sidebands flanking the line. Sign convention follows tengri.analysis.diagnostics.spectral.equivalent_width(): positive EW for emission, negative for absorption.

Parameters:
  • window_aa (float, optional) – Half-width of the line integration window [Angstrom]. Default 20.

  • continuum_width_aa (float, optional) – Width of each sideband used to estimate the continuum [Angstrom]. Sidebands sit at [lambda_0 +/- window +/- continuum_width]. Default 50.

Returns:

{line_name: (median_EW, lo_68, hi_68)} in [Angstrom]. For MAP results, all three values coincide.

Return type:

dict

Raises:

ValueError – If eline_fluxes / eline_wavelengths are unavailable, or if no _model is attached to compute the continuum.

Notes

The continuum is estimated locally per line from the model-predicted rest-frame SED, not from a precomputed continuum-only grid. The sideband choice therefore must avoid contamination from neighbouring emission lines; defaults are tuned for optical BPT lines.
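The sideband scheme can be sketched in NumPy (a stand-alone illustration; the library integrates the model-predicted rest-frame SED, and window_aa / continuum_width_aa mirror the parameters above):

```python
import numpy as np

def equivalent_width(wave, flux, line_center, window_aa=20.0, continuum_width_aa=50.0):
    """Sideband EW estimate; positive for emission, matching the convention above."""
    wave = np.asarray(wave, dtype=float)
    flux = np.asarray(flux, dtype=float)
    in_line = np.abs(wave - line_center) <= window_aa
    # Sidebands flank the line window on both sides
    lo = ((wave >= line_center - window_aa - continuum_width_aa)
          & (wave < line_center - window_aa))
    hi = ((wave > line_center + window_aa)
          & (wave <= line_center + window_aa + continuum_width_aa))
    cont = np.median(np.concatenate([flux[lo], flux[hi]]))
    # Integrate (F - C) / C over the line window (trapezoidal rule)
    x, y = wave[in_line], (flux[in_line] - cont) / cont
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))
```

A Gaussian emission line of unit amplitude and sigma = 2 Å on a flat continuum yields EW ≈ sqrt(2π) × 2 ≈ 5.01 Å.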

Examples

>>> ew = result.equivalent_widths()
>>> ha_med, ha_lo, ha_hi = ew["Halpha"]
line_fluxes() dict[str, tuple[float, float, float]][source]

Emission line flux posterior summaries.

Returns median and 68% credible interval for each emission line.

Returns:

{line_name: (median, lo_68, hi_68)} for each emission line. Flux units match the input data [erg/s/cm²]. For MAP results, all three values are the same (single-point estimate).

Return type:

dict

Raises:

ValueError – If no emission line fluxes are available. Set eline_mode="marginalized" or "fitted" in Spectroscopy to enable.

Notes

Each line’s credible interval is computed as the 16th, 50th, and 84th percentiles of the posterior samples. For single samples (MAP), all three values coincide.

Examples

>>> fluxes = result.line_fluxes()
>>> ha_median, ha_lo, ha_hi = fluxes["Halpha"]
classmethod load(path: str, model=None) Posterior[source]

Load a Posterior from an HDF5 file.

Parameters:
  • path (str) – Path to HDF5 file saved by save().

  • model (SEDModel, optional) – Model reference for derived quantity computation. If provided, enables derived, plot_sed(), plot_sfh().

Returns:

Loaded posterior with all attributes restored.

Return type:

Posterior

Notes

Reads HDF5 file saved by save(). The _fitter back-reference is not restored (it is runtime-only). To use refine() or validate(), set the _fitter attribute manually or reload using model.fit(...). Provide model to enable derived quantity computation.

Examples

>>> result = Posterior.load("posterior_result.h5", model=model)
>>> print(result.summary_table())
>>> derived = result.derived  # If model is provided
log_evidence: float | None = None
loss_history: Array | None = None
method: str
params: dict
plot_corner(params=None, truths=None, figsize=None, color='C0', fig=None, axes=None, label=None)[source]

Plot corner (triangle) plot of posterior distributions.

Parameters:
  • params (list of str, optional) – Parameter names to include. Defaults to all scalar physical params. Automatically excludes psd_xi (latent field) and constant parameters.

  • truths (dict, optional) – True values to mark with dashed lines. Keys should match parameter names.

  • figsize (tuple, optional) – Figure size (width, height). Default: auto-scaled.

  • color (str) – Color for this posterior’s contours and histograms.

  • fig (matplotlib Figure, optional) – If provided, overlay on an existing corner plot (for comparing posteriors).

  • axes (ndarray of Axes, optional) – Axes array of an existing corner plot to overlay on; pass together with fig.

  • label (str, optional) – Legend label for this posterior (appears in legend on diagonal).

Returns:

fig – The corner plot figure.

Return type:

matplotlib Figure

Notes

Creates an N×N triangle plot (lower triangle only). Diagonal shows 1D marginal distributions with KDE overlay and quantile lines. Off-diagonal shows 2D histograms with credible region contours at 68% and 95%. Includes derived quantities (stellar_mass, sfr_100myr, sfr_10myr) if _model is available. Derived quantities are plotted in log10 space.

Examples

>>> fig = result.plot_corner(color="C0", label="VI")
>>> fig = result_mcmc.plot_corner(fig=fig, axes=fig.axes, color="C1", label="MCMC")
>>> plt.show()
plot_sed(n_draws=200, wave_range=(1000, 30000), ax=None)[source]

Plot posterior predictive SED with credible interval.

Draws n_draws parameter samples from the posterior, computes the rest-frame SED for each, and shades the 16th–84th percentile band around the median.

Parameters:
  • n_draws (int) – Number of posterior draws to use for the band. Ignored for MAP results (plots single SED).

  • wave_range ((float, float)) – Wavelength range in [Angstrom] to display.

  • ax (matplotlib Axes, optional) – Axes to plot on. Creates new figure if None.

Returns:

fig – The SED plot figure.

Return type:

matplotlib Figure

Notes

Plots λ F_λ (rest-frame spectral energy density) normalized at 5500 Å. For MAP results, shows the single best-fit SED. For sampling methods, draws n_draws random samples and computes percentiles of the SED over those draws. Requires _model to be available (set by model.fit()). Uses log-log axes for visibility across wavelength range.

Examples

>>> fig = result.plot_sed(n_draws=500, wave_range=(1000, 10000))
>>> plt.show()
plot_sfh(n_draws=200, ax=None)[source]

Plot posterior SFH with credible interval.

Parameters:
  • n_draws (int) – Number of posterior draws. Ignored for MAP (plots single SFH).

  • ax (matplotlib Axes, optional) – Axes to plot on. Creates new figure if None.

Returns:

fig – The SFH plot figure.

Return type:

matplotlib Figure

Notes

Plots both the stochastic SFH (burst component) and smooth component. For MAP: shows single best-fit SFH (both stochastic and smooth). For sampling: draws n_draws random samples and computes 16th–84th percentile bands. X-axis: lookback time [Gyr]. Y-axis: SFR [Msun/yr]. Requires _model to be available (set by model.fit()).

Examples

>>> fig = result.plot_sfh(n_draws=500)
>>> plt.show()
posterior_predictive(data: Array, noise: Array, n_samples: int | None = None, key=None) dict[str, Array][source]

Posterior predictive predictions, residuals, and chi^2 distribution.

Pushes posterior draws (or the MAP point estimate) through the attached forward model’s predict_photometry and reports per-draw predictions, standardised residuals, and a chi^2 distribution against the supplied data + noise.

Parameters:
  • data (array_like, shape (n_obs,)) – Observed data the posterior was conditioned on. Same units as predict_photometry’s output.

  • noise (array_like, shape (n_obs,)) – Per-observation 1-sigma uncertainty (Gaussian).

  • n_samples (int, optional) – How many posterior draws to evaluate. None (default) uses every available draw; for MAP results this is implicitly 1. For sampling results, draws are selected via resample() (with replacement) using key.

  • key (PRNGKey, optional) – JAX PRNG key for resampling. If None, defaults to jax.random.PRNGKey(0).

Returns:

  • predictions (ndarray, shape (N, n_obs)) – Model prediction for each draw.

  • residuals (ndarray, shape (N, n_obs)) – Standardised residuals, (data - prediction) / noise.

  • chi2 (ndarray, shape (N,)) – Per-draw \(\chi^2 = \sum_i r_i^2\).

  • chi2_median, chi2_lo, chi2_hi (float) – 16/50/84 percentiles of the chi^2 distribution.

N is 1 for MAP, otherwise n_samples (or the full chain length if n_samples is None).

Return type:

dict

Raises:

ValueError – If no _model is attached.

Notes

This is a deterministic posterior predictive (no extra noise realisation per draw). For replicated PPCs that draw observation noise per sample, layer noise * jax.random.normal on top of predictions.
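The replicated variant described above can be sketched as follows (dummy NumPy arrays stand in for ppc["predictions"] and the noise array; in the JAX workflow, swap default_rng for jax.random.normal with a PRNGKey):

```python
import numpy as np

# Replicated PPC sketch: add a fresh observation-noise realisation per draw.
pred = np.ones((100, 5))            # stand-in for ppc["predictions"], shape (N, n_obs)
noise = np.full(5, 0.1)             # 1-sigma per observation
rng = np.random.default_rng(0)
y_rep = pred + noise * rng.standard_normal(pred.shape)   # one noise draw per sample
```

Comparing the chi^2 of y_rep against the observed chi^2 gives the classic posterior predictive check.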

refine(method: str, **kwargs)[source]

Re-run inference from this result using a different method.

Requires that this Posterior was produced by model.fit() or fitter.run() — both set the ._fitter back-reference.

Parameters:
  • method (str) – Any canonical method name accepted by Fitter.run(). E.g. "mcmc_raytrace", "mcmc_nuts", "vi".

  • **kwargs – Passed to Fitter.run() (e.g. n_steps, n_warmup).

Returns:

New result warm-started from this posterior.

Return type:

Posterior

Raises:

RuntimeError – If ._fitter is not set (Posterior created outside model.fit/fitter.run).

Notes

Warm-starts the new inference from this posterior’s parameters or samples. Common use cases: VI → MCMC refinement (exact inference on top of variational fit), MCMC → different sampler (e.g. raytrace → nuts), or quick method → expensive method for publication.

Examples

>>> result_vi = model.fit(flux, noise)
>>> result_exact = result_vi.refine("mcmc_raytrace", n_steps=1000)
resample(key, n=1) dict[source]

Resample from posterior with replacement.

Parameters:
  • key (PRNGKey) – JAX random key.

  • n (int) – Number of resamples.

Returns:

If n=1: parameter name → scalar value. If n>1: parameter name → array of shape (n, …).

Return type:

dict

Notes

For MAP results, returns the point estimate (repeated n times if n > 1). For sampling results, draws n indices uniformly from [0, n_samples) with replacement. Use for Monte Carlo propagation of posterior uncertainty through forward models.

Examples

>>> key = jax.random.PRNGKey(0)
>>> sample = result.resample(key, n=1)  # Single draw
>>> samples = result.resample(key, n=100)  # 100 resamples
>>> sfhs = [model.predict_sfh({k: v[i] for k, v in samples.items()})
...         for i in range(100)]  # Propagate through the forward model
rhat(exclude_prefixes: tuple[str, ...] = ('psd_xi',)) dict[str, float][source]

Per-parameter split-\(\hat R\) (Gelman-Rubin).

Splits each parameter chain in half and computes the classical \(\hat R = \sqrt{\hat V / W}\) against the two halves. \(\hat R \approx 1.0\) indicates convergence; \(\hat R > 1.01\) (Vehtari+2021) suggests failure to mix.

Parameters:

exclude_prefixes (tuple of str, optional) – Parameter name prefixes to skip. Default skips psd_xi (GP latent vector — high-D, not informative per-component).

Returns:

Parameter name → \(\hat R\). Static (zero-variance) and excluded parameters are dropped.

Return type:

dict

Raises:

ValueError – If this is a MAP result (no samples to split).

See also

tengri.analysis.diagnostics.autocorrelation.split_rhat

Underlying implementation.
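The split-half computation can be sketched in NumPy (illustrative; the library implementation lives in tengri.analysis.diagnostics.autocorrelation.split_rhat):

```python
import numpy as np

def split_rhat(chain):
    """Classical split-Rhat on a single chain (illustrative sketch)."""
    x = np.asarray(chain, dtype=float)
    n = len(x) // 2
    halves = np.stack([x[:n], x[n:2 * n]])      # two half-chains
    w = halves.var(axis=1, ddof=1).mean()       # within-chain variance W
    b = n * halves.mean(axis=1).var(ddof=1)     # between-chain variance B
    v_hat = (n - 1) / n * w + b / n             # pooled variance estimate
    return float(np.sqrt(v_hat / w))
```

A well-mixed chain gives values near 1.0; a drifting (unconverged) chain inflates the between-half variance and pushes the statistic well above 1.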

Examples

>>> rh = result.rhat()
>>> bad = {k: v for k, v in rh.items() if v > 1.05}
>>> if bad:
...     print(f"Unconverged: {bad}")
samples: dict | None
save(path: str) None[source]

Save posterior to HDF5 file.

Serializes samples, params, diagnostics, loss history, and emission line data. Model and fitter references are NOT saved (they are non-serializable runtime objects).

Parameters:

path (str) – Output HDF5 file path.

Return type:

None

Notes

Saves to HDF5 format with groups:

  • samples: posterior samples (if available)

  • params: best-fit or MAP parameters

  • loss_history: optimization loss over iterations (if available)

  • diagnostics: method-specific convergence metrics

  • eline: emission line fluxes, covariances, names, wavelengths

Use load() to restore the Posterior from disk.

Examples

>>> result.save("posterior_result.h5")
>>> later_result = Posterior.load("posterior_result.h5")
sed_components(wavelength=None) dict[source]

Per-component SEDs across posterior draws.

Decomposes the model prediction into stellar (attenuated / intrinsic) + nebular + shock + dust IR + AGN + radio + X-ray for every posterior sample (or the MAP point estimate) by running the orchestrator pipeline per draw and reading the per-component SED arrays each adapter publishes into state.derived.

Parameters:

wavelength (array_like, optional) – Ignored — kept for backwards compatibility. The orchestrator uses the model’s SSP wavelength grid; pass-through to a different grid is no longer supported here. Interpolate the returned arrays externally if a different grid is needed.

Returns:

Keys: wavelength plus the entries of Posterior._COMPONENT_KEYS (sed_total, sed_attenuated, sed_intrinsic, sed_nebular, sed_shock, sed_dust_ir, sed_agn, sed_radio, sed_xray). Each component array has shape (n_wave,) for MAP and (n_samples, n_wave) for sampling.

Return type:

dict

Raises:

ValueError – If no _model is attached.

Notes

Reads the orchestrator’s per-component publications: sed_dust_attenuated (stellar post-attenuation), sed_dust_ir, sed_nebular, sed_shock, sed_agn, sed_radio, sed_xray — each adapter publishes its own contribution into state.derived. sed_total is the accumulated state.sed_intrinsic after the chain runs; sed_intrinsic (stellar pre-attenuation) is reconstructed from the published lnu_age cube via sum(lnu_age, axis=0).

For Posterior draws this is still a Python loop (one orchestrator pass per sample); JIT vectorisation across draws is a separate optimisation.

Examples

>>> comp = result.sed_components()
>>> agn_to_total = comp["sed_agn"] / comp["sed_total"]
stats() dict[source]

Median and 68% credible intervals for all parameters.

Returns:

Keys: parameter names (excluding "psd_xi" latent field). Values: dict with "median", "lo_68", "hi_68" for sampling methods, or "value" for MAP.

Return type:

dict

Notes

For MCMC and VI results, credible intervals are 16th and 84th percentiles. For MAP results, returns point estimates without intervals. Does not include high-dimensional latent fields (psd_xi).

Examples

>>> stats = result.stats()
>>> print(stats["stellar_mass"])
{"median": 10.5, "lo_68": 10.3, "hi_68": 10.7}  # sampling
# or
{"value": 10.5}  # MAP
summary() None[source]

Print the per-parameter median ± 68% credible interval table.

Convenience wrapper around summary_table() that prints directly — typically the first call after a fit:

>>> posterior = fitter.run("nuts")
>>> posterior.summary()
summary_str() str[source]

Return the summary as a string (parity with Parameters.summary_str).

Same content as summary_table() — alias for naming consistency across the discovery API (tengri.Parameters.summary_str returns the same kind of string).

summary_table() str[source]

Return a formatted string table of parameter summaries.

For MAP: shows parameter values. For sampling: shows median with 68% credible intervals and ESS.

Returns:

Formatted table string with method, sample count, wall time, parameter statistics, and diagnostics.

Return type:

str

Notes

The table includes:

  • Method name and number of samples (or "MAP")

  • Wall-clock time in seconds

  • Parameter names with median and credible intervals

  • Effective sample size (ESS) for sampling methods

  • Method-specific diagnostics (accept rate, divergences, loss, etc.)

  • Log evidence (if available from nested sampling)

Examples

>>> print(result.summary_table())
Posterior  method: mcmc_nuts  samples: 1000  wall_time: 5.2s
─────────────────────────────────────────────────────────────
  Parameter                   Median        16%        84%     ESS
  ─────────────────────────────────────────────────────────────
  ...
to_arviz()[source]

Convert to ArviZ InferenceData for diagnostics.

Returns:

ArviZ InferenceData object with posterior group containing all scalar parameters.

Return type:

az.InferenceData

Notes

Requires arviz to be installed: pip install arviz. High-dimensional latent fields (psd_xi) are excluded. Samples are reshaped to (1, n_samples) format (1 chain). Use ArviZ tools for advanced visualization and diagnostics (forest plots, rank plots, etc.).

Examples

>>> idata = result.to_arviz()
>>> az.plot_forest(idata)
>>> az.summary(idata)
to_param_spec()[source]

Convert posterior to an empirical Parameters.

For MAP: all parameters become Fixed at their best-fit values. For sampling: fit clipped Gaussian to each marginal.

Returns:

New Parameters object with priors fit to the posterior.

Return type:

Parameters

Notes

For MAP results, all parameters become Fixed at the MAP value. For sampling methods, each parameter gets a Gaussian prior with:

  • mean: median of samples

  • sigma: standard deviation of samples

  • bounds: [min, max] from samples (clipping)

Inherits stochastic and n_grid settings from the original model.

Examples

>>> posterior_params = result.to_param_spec()
>>> # Use as starting point for next fit
>>> refined = fitter.run("mcmc_nuts", init_from=posterior_params)
validate(n_steps: int = 200, **kwargs)[source]

Run a short MCMC check and return a validation summary.

Runs n_steps of Ray Tracing (or NUTS for D≤20) from this posterior’s MAP estimate, then computes the marginal overlap between this posterior and the MCMC check posterior for each parameter.

Parameters:
  • n_steps (int) – Number of MCMC steps. Default 200 (quick sanity check).

  • **kwargs – Forwarded to the MCMC run.

Returns:

Keys: "mcmc_result" (Posterior from MCMC check), "overlap" (dict of float per parameter, 1.0 = perfect overlap, 0.0 = no overlap), "passed" (bool, True when all overlaps > 0.5).

Return type:

dict

Raises:

RuntimeError – If ._fitter is not set.

Notes

Validation checks whether a quick MCMC run agrees with the current posterior (typically from VI or MAP). High overlap (> 0.5) indicates the method is reliable; low overlap suggests the posterior may be biased or misspecified. Overlap is computed as the histogram intersection at each parameter. For sampling methods (VI, geoVI), validates the approximate posterior. For MCMC methods, serves as a sanity check for chain convergence.
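The histogram-intersection overlap can be sketched as follows (illustrative; the bin count and shared range here are assumptions, not the library's exact choices):

```python
import numpy as np

def marginal_overlap(a, b, bins=30):
    """Histogram-intersection overlap of two 1-D sample sets (1 = identical, 0 = disjoint)."""
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    lo, hi = min(a.min(), b.min()), max(a.max(), b.max())
    pa, _ = np.histogram(a, bins=bins, range=(lo, hi))
    pb, _ = np.histogram(b, bins=bins, range=(lo, hi))
    pa, pb = pa / pa.sum(), pb / pb.sum()
    # Sum of the bin-wise minima of the two normalized histograms
    return float(np.minimum(pa, pb).sum())
```

Two draws from the same distribution score near 1; posteriors concentrated in disjoint regions score near 0, which is what the > 0.5 pass threshold screens for.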

Examples

>>> result_vi = model.fit(flux, noise, method="vi")
>>> val = result_vi.validate(n_steps=500)
>>> print(f"Validation passed: {val['passed']}")
>>> for param, ov in val["overlap"].items():
...     print(f"{param}: overlap={ov:.3f}")
wall_time_s: float

PopulationFitter

class tengri.PopulationFitter(model_factory, galaxies, psd_sigma_prior=(0.1, 4.0), psd_tau_prior=(1.0, 300.0), data_type='photometry')[source]

Bases: object

Hierarchical inference for shared PSD parameters.

Manages population-level inference via hierarchical VI, learning the shared PSD hyperparameters (σ_PSD, τ_PSD) across a population of galaxies while preserving per-galaxy latent fields and physical parameters.

Parameters:
  • model_factory (callable) – Function(psd_sigma, psd_tau_myr) → Model. Creates a model with the given PSD params. All other params (SFH, dust, etc.) come from the model’s Parameters.

  • galaxies (list of dict) – Each dict has ‘flux_obs’, ‘noise’, and optionally ‘spec_obs’, ‘spec_noise’, ‘wave_spec’.

  • psd_sigma_prior (tuple) – (lo, hi) for uniform prior on σ_PSD.

  • psd_tau_prior (tuple) – (lo, hi) for uniform prior on τ_PSD (Myr).

  • data_type (str) – “photometry” or “spectroscopy”.

Notes

Wraps all hierarchical inference methods (EVI, VI, MCMC) with automatic initialization via per-galaxy MAP estimation. The class builds a single flat parameter vector [φ_shared, ξ_1, θ_1, …, ξ_N, θ_N] and optimizes it via the specified method.
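The flat-vector layout can be illustrated with dummy shapes (the 16-dimensional latent field and 2 physical parameters per galaxy are assumptions for illustration only):

```python
import numpy as np

# Flat hierarchical vector [phi_shared, xi_1, theta_1, ..., xi_N, theta_N].
n_gal, n_xi, n_th = 3, 16, 2
phi_shared = np.array([0.5, 100.0])          # e.g. (psd_sigma, psd_tau_myr)
per_galaxy = [(np.zeros(n_xi), np.array([10.5, -2.0])) for _ in range(n_gal)]

flat = np.concatenate(
    [phi_shared] + [np.concatenate([xi, th]) for xi, th in per_galaxy]
)

# Unpacking: shared block first, then fixed-size per-galaxy blocks.
blocks = flat[len(phi_shared):].reshape(n_gal, n_xi + n_th)
```

Because every per-galaxy block has the same size, packing and unpacking are simple reshapes, which is what makes the single flat vector amenable to JAX optimizers and samplers.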

n_galaxies

Number of galaxies in the population.

Type:

int

Examples

>>> import jax
>>> from tengri import PopulationFitter, SEDModel, Parameters
>>> # Define a factory that builds a model given shared PSD params
>>> def model_factory(psd_sigma, psd_tau_myr):
...     spec = Parameters(
...         sfh_field_psd_sigma=psd_sigma,      # wire in the shared values,
...         sfh_field_psd_tau_myr=psd_tau_myr,  # not independent free priors
...     )
...     return SEDModel(spec, ssp_data)  # ssp_data loaded separately
>>> # galaxies = [{'flux_obs': ..., 'noise': ...}, ...]
>>> # pop = PopulationFitter(model_factory, galaxies)
>>> # result = pop.run('vi', key=jax.random.PRNGKey(0))
run(method='native_vi_linear', *, key=None, **kwargs)[source]

Run hierarchical inference.

Parameters:
  • method (str) –

    Pure-JAX (lax.while_loop, no NIFTy; recommended)

    • "native_vi_linear" — MGVI inside lax.while_loop (default). 3–4× faster than NIFTy MGVI on CPU; O(1) memory in N.

    • "native_vi_nonlinear" — geoVI inside lax.while_loop. Comparable speed to NIFTy geoVI; prefers lower N (≤20).

    NIFTy-backed (CorrelatedFieldMaker, native PSD learning)

    • "vi_nonlinear_fast" — geoVI via NIFTy optimize_kl.

    • "vi_nonlinear" — geoVI; same runner as fast, kept for API symmetry.

    • "vi_linear_fast" — MGVI via NIFTy optimize_kl.

    • "vi_linear" — MGVI; same runner as fast, kept for API symmetry.

    MCMC

    • "mcmc_raytrace" — Ray Tracing on flat vector.

    • "mcmc_ess" — Ensemble sampling (alias for native_vi_linear).

  • key (PRNGKey, optional) – Random key for reproducibility. If None, uses PRNGKey(0).

  • **kwargs – Passed to the inference method.

Returns:

Results object with shared_params, shared_samples, individual_samples, and diagnostics.

Return type:

PopulationPosterior

Notes

Unlike Fitter.run(), PopulationFitter.run() does not support warm-start initialization via init_from because hierarchical inference methods (EVI, CFM-based geoVI, Ray Tracing) initialize per-galaxy parameters via MAP estimation automatically. The init_from parameter is not meaningful in the hierarchical context.

Automatic initialization: all methods initialize per-galaxy parameters via MAP estimation before starting the hierarchical inference. First call may be slow due to JIT compilation. Subsequent calls are fast. Approximate runtime: ~30 seconds for 10 galaxies on CPU (method-dependent).

Compile-cost amortization across catalog sizes

Unlike CatalogFitter, PopulationFitter does not expose an n_pad argument: the hierarchical population field couples all N galaxies, so padding with dummy galaxies would contribute spurious prior mass to the population hyperparameters (e.g. psd_sigma, psd_tau_myr) even with masked likelihoods. To amortize XLA compile cost across notebook restarts, slurm tasks, or sweeps over different catalog sizes, rely on the persistent compilation cache instead — see tengri.enable_persistent_cache() and docs/inference/compilation_cache.md.

PopulationPosterior

class tengri.PopulationPosterior(shared_samples: dict, shared_params: dict, individual_samples: list | None = None, method: str = '', wall_time_s: float = 0.0, diagnostics: dict = <factory>)[source]

Bases: object

Results from hierarchical PSD inference.

This class holds the posterior distribution over shared PSD hyperparameters (σ_PSD, τ_PSD) inferred across a population of galaxies, along with optional per-galaxy individual posteriors.

Parameters:
  • shared_samples (dict) – Posterior samples for shared PSD params. Keys are param names (e.g., ‘psd_sigma’, ‘psd_tau_myr’), values are arrays of shape (n_samples,).

  • shared_params (dict) – Posterior mean of shared PSD params (computed from shared_samples).

  • individual_samples (list of dict, optional) – Per-galaxy posterior samples. Each element is a dict with per-galaxy parameter names as keys. If None, individual posteriors are not stored (for memory efficiency).

  • method (str) – Inference method used (e.g., “Hierarchical EVI (JIT)”).

  • wall_time_s (float) – Total wall-clock time for inference.

  • diagnostics (dict) – Method-specific diagnostics (e.g., number of iterations, convergence info).

Notes

This dataclass is the return type for PopulationFitter.run(). Access population posteriors via shared_samples and shared_params, and per-galaxy posteriors via the individual property (returns a list of lightweight objects with .samples and .params attributes).

Examples

>>> result = fitter.run("vi", n_iterations=20)
>>> result.summary()  # Median and 68% credible intervals
>>> ax = result.plot_population(params=("sfh_field_psd_sigma", "sfh_field_psd_tau_myr"))
diagnostics: dict
property individual

Per-galaxy posterior marginals as a list of lightweight objects.

Parameters:

None

Returns:

Each element has .samples (dict) and .params (dict). Returns empty list if individual_samples is None.

Return type:

list of SimpleNamespace

Notes

Each per-galaxy posterior is marginalized over the shared PSD hyperparameters. The .params field contains the median of each per-galaxy parameter.

individual_samples: list | None = None
method: str = ''
plot_population(params=('sfh_field_psd_sigma', 'sfh_field_psd_tau_myr'), ax=None)[source]

Scatter plot of shared PSD parameter posteriors.

Parameters:
  • params (tuple of str) – Two parameter names for x and y axes.

  • ax (matplotlib Axes, optional) – If None, creates a new figure.

Returns:

The axes object with the scatter plot.

Return type:

matplotlib Axes

Notes

Plots posterior samples as a scatter cloud. Each point is one posterior sample.

population_diagnostics(exclude_prefixes: tuple[str, ...] = ('psd_xi',)) dict[source]

Convergence diagnostics for the population fit.

Computes split-Rhat and effective sample size for the shared PSD block, and (when individual_samples is present) aggregates per-galaxy diagnostics into population-level summaries.

Parameters:

exclude_prefixes (tuple of str, optional) – Parameter-name prefixes to skip when computing per-galaxy diagnostics. Default ("psd_xi",) skips the GP latent fields, which carry one chain entry per grid point and inflate dict size without adding interpretable information.

Returns:

Two-level structure:

{
    "shared": {
        param_name: {"rhat": float, "ess": float},
        ...
    },
    "per_galaxy": {                # only if individual_samples is set
        param_name: {
            "rhat_p50": float,    # median across galaxies
            "rhat_p90": float,
            "rhat_max": float,
            "ess_p50": float,
            "ess_min": float,
            "n_galaxies": int,
        },
        ...
    },
}

Use "per_galaxy" to spot a single galaxy whose chain has stalled — a high rhat_max with low rhat_p50 is the signature.

Return type:

dict

Notes

Reuses tengri.analysis.diagnostics.rhat() and tengri.analysis.diagnostics.effective_sample_size(). Static (zero-variance) parameters are dropped silently.

Examples

>>> result = fitter.run("vi", n_iterations=20)
>>> diag = result.population_diagnostics()
>>> diag["shared"]["sfh_field_psd_sigma"]
{'rhat': 1.012, 'ess': 318.4}
shared_params: dict
shared_samples: dict
summary() dict[source]

Median and 68% CI for shared PSD parameters.

Parameters:

None

Returns:

Dictionary mapping parameter names to summary statistics. Each parameter has keys ‘median’, ‘lo_68’ (16th percentile), and ‘hi_68’ (84th percentile).

Return type:

dict

wall_time_s: float = 0.0

sample_raytrace

tengri.sample_raytrace(key, params_init, log_prob_fn, n_steps, n_leapfrog_steps, step_size, refresh_rate=0, metro_check=1, sample_hmc=False, integrator='dkd')[source]

Run Ray Tracing Sampler and return the full Markov chain.

Parameters:
  • key (PRNGKey) – Random key.

  • params_init (array) – Initial parameter values (flat 1D array).

  • log_prob_fn (callable) – Function mapping params → scalar log probability.

  • n_steps (int) – Number of MCMC steps (samples to collect).

  • n_leapfrog_steps (int) – Leapfrog integration steps per trajectory.

  • step_size (float) – Leapfrog step size. Recommended: ~0.03 * sqrt(D).

  • refresh_rate (float) – Partial momentum refresh rate. 0 = no refresh.

  • metro_check (int) – 1 = apply Metropolis correction, 0 = skip.

  • sample_hmc (bool) – If True, use HMC instead of ray tracing.

  • integrator (str) – Leapfrog integrator scheme: "dkd" (Drift-Kick-Drift, default) or "kdk" (Kick-Drift-Kick). KDK is the time-reversed partner of DKD. Both are symplectic and second-order. DKD matches Behroozi’s reference implementation.

Returns:

  • chain (array, shape (n_steps, D)) – Parameter samples.

  • log_likelihood (array, shape (n_steps,)) – Log-likelihood at each accepted sample.

  • accept_prob (array, shape (n_steps,)) – Acceptance probability at each step.

Notes

JIT-compatible: no — wraps blackjax sampler with Python-level loop.

Gradient-safe: no — MCMC sampler, not a differentiable operation.
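The two integrator schemes differ only in operator ordering; a DKD trajectory for H(q, p) = -log p(q) + |p|²/2 can be sketched as (an illustrative sketch, not the blackjax-backed implementation):

```python
import numpy as np

def leapfrog_dkd(q, p, grad_logp, step_size, n_steps):
    """Drift-Kick-Drift leapfrog trajectory (symplectic, second order)."""
    for _ in range(n_steps):
        q = q + 0.5 * step_size * p            # drift (half step)
        p = p + step_size * grad_logp(q)       # kick (full step)
        q = q + 0.5 * step_size * p            # drift (half step)
    return q, p
```

KDK swaps the roles: half kick, full drift, half kick. Both conserve a shadow Hamiltonian, so the energy error stays bounded rather than drifting.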

Examples

>>> import jax
>>> import jax.numpy as jnp
>>> from tengri import sample_raytrace
>>> key = jax.random.PRNGKey(0)
>>> log_prob_fn = lambda x: -0.5 * jnp.sum(x**2)
>>> chain, lnl, acc = sample_raytrace(
...     key, jnp.zeros(5), log_prob_fn, n_steps=10, n_leapfrog_steps=5, step_size=0.1
... )
>>> chain.shape
(10, 5)