Core API

The core classes that form tengri’s high-level interface: defining models, specifying parameters, generating predictions, and creating mock data.

SEDModel

class tengri.SEDModel(spec, ssp_data, filters=None, observation=None, precompute=True, forward_dtype='float64', approx=None, csp_integration='trapz', wave_chunk_size=None, agn_config=None)

Bases: object

Differentiable SED forward model with modular physics and clean API.

The forward model maps physical parameters (stellar mass, SFH, metallicity, dust, AGN, etc.) to observables: photometry, spectrum, and derived SED quantities. Internally, it decomposes the SED pipeline into independent physics modules (stellar populations, star formation history, dust, nebular, AGN, IGM) that are composed into prediction kernels at initialization time, enabling fast inference and flexibility in model configuration.

The SFH is computed via a registry-driven composed function that handles additive smooth models, burst mixtures, and correlated-field (GP) modulation in a single call. Three prediction modes (compositional, hybrid, exact) differ in speed and in which configurations they support; mode selection falls back automatically.

Parameters:
  • spec (Parameters) – Parameter specification from tengri.Parameters. Defines free/fixed parameters and their priors.

  • ssp_data (SSPData) – Pre-loaded SSP templates (from load_ssp_data()). Contains the SSP grid indexed by absolute log10(Z), the age array, and optional mass-remaining tables used to compute the surviving stellar mass.

  • filters (list or tuple, optional) –

    Filter transmission curves for photometric prediction. Accepts either:

    • 3-tuple from load_filter_set(): (filter_waves, filter_trans, filter_curves)

    • List of FilterCurve namedtuples

    If provided, enables photometry prediction and automatic precomputation at initialization. Either filters or observation may be passed, not both.

  • observation (Observation, optional) – Unified observation config (photometry + spectroscopy + emission lines). Mutually exclusive with filters.

  • precompute (bool, optional) – Whether to precompute SSP photometry and spectroscopy grids at initialization. The default, True, activates the Zacharegkas+2025 fast-photometry path and enables caching of spectroscopy grids. Set False to defer computation (useful for batch operations).

  • forward_dtype (str or jnp.dtype, optional) –

    Dtype for forward model computation. Default "float64" preserves full precision. "float32" halves memory and gives ~1.5× speedup with <0.1% accuracy loss for photometry.

    Affects both fused (photometry + precomputation) and exact paths:

    • Fused path: captured arrays (SSP grid, dust weights, effective wavelengths) cast to forward_dtype at kernel build; outputs always cast back to float64 for cosmological distance scaling.

    • Exact path (spectroscopy, non-precomputed AGN): three largest intermediates — metallicity-interpolated SSP (n_age, n_wave), dust attenuation (n_age, n_wave), dust age weights (n_age,) — computed in forward_dtype, halving the 4.5 MB memory traffic that dominates exact-path dust cost.

    Cosmological distances always use float64 (float32 overflows at z > 0.01).

  • approx (dict or bool, optional) –

    Control which approximations the fused kernel uses. Default True enables all approximations (fastest). False disables all (forces exact path everywhere). A dict enables selective control:

    • "dust_attenuation": use dust at filter effective wavelengths (True, default)

    • "dust_emission": use MBB at filter effective wavelengths (True, default)

    • "igm": use IGM at filter effective wavelengths (True, default)

    Approximation accuracy (Zacharegkas+2025 [1]):

    • dust_attenuation: <3% for most laws, ~36% for SMC

    • dust_emission: negligible for optical (MBB peak >50 μm)

    • igm: exact for fixed z (precomputed once)

  • csp_integration (str, optional) – CSP age integration scheme. Default "trapz" (trapezoidal on linear time). Options: "log_trapz", "log_interp" (Dopita+2005 interpolation), "dsps_native" (DSPS trapezoidal with automatic metallicity marginalization), "dsps_met_table" (time-evolving metallicity table). See Appendix A of the forward model paper [2].
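The three accepted forms of approx (True, False, or a dict) can be read as resolving to one flag per approximation. A minimal sketch using a hypothetical helper resolve_approx (not part of tengri), assuming that None behaves like the documented default of True and that keys omitted from a dict keep their default of True:

```python
from collections.abc import Mapping
from typing import Union

# Approximation names documented for SEDModel(approx=...).
APPROX_KEYS = ("dust_attenuation", "dust_emission", "igm")


def resolve_approx(approx: Union[None, bool, Mapping]) -> dict:
    """Hypothetical helper: expand approx into one boolean per approximation.

    Assumption: None acts like True, and keys missing from a dict default to True.
    """
    if approx is None or approx is True:
        return {k: True for k in APPROX_KEYS}
    if approx is False:
        return {k: False for k in APPROX_KEYS}
    unknown = set(approx) - set(APPROX_KEYS)
    if unknown:
        # Reject typos instead of silently ignoring them.
        raise KeyError(f"unknown approximation(s): {sorted(unknown)}")
    return {k: bool(approx.get(k, True)) for k in APPROX_KEYS}
```

For example, resolve_approx({"dust_attenuation": False}) keeps the emission and IGM approximations on while forcing exact dust attenuation.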

observation

Attached observation object containing photometry and/or spectroscopy configuration. Set by the constructor when filters= or observation= is passed.

Type:

Observation or None

spec

Parameter specification defining all free/fixed parameters and their priors.

Type:

Parameters

ssp_data

Pre-loaded stellar population synthesis templates (from load_ssp_data()).

Type:

SSPData

config

Frozen model configuration (immutable after init).

Type:

ModelConfig

Notes

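JIT-compatible: mostly. The core prediction methods (e.g., predict_photometry(), predict_sfh_quantities(), predict_sed_quantities()) are JAX-differentiable and can be called inside jax.jit() and jax.vmap(). Exceptions are predict() (lazy evaluation with Python-side caching) and the convenience wrappers marked "JIT-compatible: no" below.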

Gradient-safe: yes — all physical parameters are differentiable for inference via HMC, VI, and score-based methods.

Approximation scheme: The forward model uses a three-tier kernel hierarchy to balance speed and accuracy:

  1. Compositional (preferred): Full-resolution JIT SED from all components → filter integration. XLA fuses entire graph (SFH → SED → photometry). Bit-exact and fastest.

  2. Hybrid (fallback): Precomputed SSP×filter stellar + exact non-stellar at full wavelength resolution.

  3. Exact (reference): Raw pipeline, no approximations or precomputation.

Mode selection in predict_photometry() and predict_spectrum(): mode="auto" (default) tries compositional, then hybrid, then exact, and uses the first mode available for the model's configuration.
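The auto cascade can be pictured as a first-available lookup. A sketch, assuming the model keeps a dict of the kernels it managed to build for its configuration; select_kernel and kernels are hypothetical names, not tengri internals:

```python
from typing import Callable, Dict

# Preference order for mode="auto", fastest first.
CASCADE = ("compositional", "hybrid", "exact")


def select_kernel(kernels: Dict[str, Callable], mode: str = "auto") -> Callable:
    """Hypothetical dispatcher over the kernels built for one model configuration."""
    if mode == "auto":
        for name in CASCADE:
            if name in kernels:
                return kernels[name]
        raise RuntimeError("no prediction kernel available")
    if mode not in kernels:
        raise ValueError(f"mode={mode!r} is not available for this model")
    return kernels[mode]
```

A model whose configuration rules out the compositional kernel (for example, one with variable redshift) would then resolve mode="auto" to the hybrid kernel.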

Physical units (internal):

  • Time: years (yr). User-facing API converts to Myr/Gyr.

  • Wavelength: Angstrom (Å).

  • Luminosity (SED components): erg/s/Hz (L_ν).

  • Flux density (photometry): erg/s/cm²/Hz (f_ν).

  • Metallicity (SSP grid): log₁₀(Z) absolute. User API uses log₁₀(Z/Z☉).

  • AGN bolometric luminosity: log₁₀(L_bol/L☉) at API level.

IGM absorption gotcha: predict_obs_sed() applies IGM transmission at observed-frame wavelengths (the wavelengths passed to igm_transmission() are already redshifted). This happens automatically when igm=True in spec.


Examples

Setting up a standard photometric fit with a DPL SFH:

from tengri import SEDModel, Parameters, Uniform, load_ssp_data, Photometry

ssp = load_ssp_data("data/ssp_miles.h5")
phot = Photometry.from_names(["sdss_r", "sdss_i", "sdss_z"])
spec = Parameters(
    redshift=0.1,
    sfh_dpl_alpha=Uniform(0.5, 4.0),
    sfh_dpl_beta=Uniform(0.3, 3.0),
)
model = SEDModel(spec, ssp, observation=phot)
compile_signature() → tuple

Return a hashable signature identifying JIT-graph shape and structure.

Two SEDModel instances with the same compile_signature() will produce identical XLA compilation graphs (for identical Fitter configurations), enabling cross-galaxy engine reuse in PopulationFitter and CatalogFitter.

The signature captures every JIT-affecting field: SSP array shapes, filter grid dimensions, dust/AGN/nebular model identities, and all configuration flags that determine the control flow during inference.

Returns:

Hashable immutable signature. Entries are immutable types (int, str, tuple, bool, None) or tuples thereof.

Return type:

tuple

Notes

This signature is used by Fitter._get_or_build_engine to key the module-level _SHARED_ENGINE_CACHE. Changes to SEDModel initialization that affect JIT graph shape MUST be added to this method to avoid silent miscompilation.
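The keying scheme can be sketched with a plain dict. The names below (_ENGINE_CACHE, get_or_build_engine) are hypothetical stand-ins for _SHARED_ENGINE_CACHE and Fitter._get_or_build_engine, not their actual code:

```python
from typing import Any, Callable, Dict, Hashable

# Hypothetical module-level cache keyed by compile signatures.
_ENGINE_CACHE: Dict[Hashable, Any] = {}


def get_or_build_engine(signature: tuple, build: Callable[[], Any]) -> Any:
    """Return the cached engine for this signature, building it only on a miss."""
    hash(signature)  # fail fast if an unhashable entry slipped into the signature
    if signature not in _ENGINE_CACHE:
        _ENGINE_CACHE[signature] = build()
    return _ENGINE_CACHE[signature]
```

This is why a missing field is dangerous: two models that differ only in an omitted field produce equal signatures, so the second one silently reuses an engine compiled for the first.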

fit(data=None, noise=None, method: str = 'vi', data_type: str | None = None, *, photometry: tuple | None = None, spectrum: tuple | None = None, init: str | None = None, **kwargs)

Fit observed data. Convenience wrapper — no Fitter construction needed.

Parameters:
  • data (array, optional) – Observed flux array (photometry or spectroscopy). For joint fitting, leave as None and use photometry= / spectrum= instead.

  • noise (array, optional) – 1-sigma uncertainties matching data.

  • method (str) – Inference method. Default "vi" (geoVI variational inference). Any canonical name accepted by Fitter.run() works here: "vi", "vi_linear", "mcmc", "mcmc_raytrace", "mcmc_nuts", "map", "laplace", "auto", etc.

  • data_type (str or None) – "photometry", "spectroscopy", or "joint". When None (default), inferred from the model’s observation or from whether photometry= / spectrum= kwargs are used.

  • photometry (tuple of (flux, noise), optional) – Photometric data for joint fitting. Pass alongside spectrum=.

  • spectrum (tuple of (flux, noise), optional) – Spectroscopic data for joint fitting. Pass alongside photometry=.

  • init (str or None) – Initialization strategy. "map" runs MAP optimization first, then uses the result to warm-start the requested method. None (default) uses the method’s own default initialization.

  • **kwargs – Forwarded to Fitter.run().

Returns:

Inference results, with ._fitter set so that .refine() works. After this call, model.fitter_ holds the Fitter instance.

Return type:

Posterior

Notes

Convenience wrapper around Fitter. For advanced usage (custom loss, multiple refinement steps), use Fitter directly.

Examples

>>> result = model.fit(flux_obs, noise)
>>> result = model.fit(flux_obs, noise, method="mcmc")
>>> result = model.fit(photometry=(flux_p, noise_p), spectrum=(flux_s, noise_s))
>>> result = model.fit(flux_obs, noise, init="map")
>>> result = model.fit(flux_obs, noise).refine("mcmc_raytrace")
fit_batch(catalog, flux_cols: list[str], err_cols: list[str], redshift_col: str | None = None, method: str = 'vi', n_workers: int = 1, verbose: bool = True, output_dir: str | None = None, id_col: str | None = None, **kwargs) → list

Fit a batch of galaxies from a catalog (DataFrame, Table, or list of dicts).

Parameters:
  • catalog (DataFrame, Table, or list of dict) – Input catalog.

  • flux_cols (list of str) – Column names for per-band flux values.

  • err_cols (list of str) – Column names for per-band 1-sigma uncertainties.

  • redshift_col (str or None) – If provided, use this column as per-row redshift.

  • method (str) – Inference method. Default "vi".

  • n_workers (int) – Currently ignored (reserved for multiprocessing). Default 1.

  • verbose (bool) – Print per-galaxy progress. Default True.

  • output_dir (str or None) – If provided, save each Posterior to {output_dir}/{id}.h5.

  • id_col (str or None) – Column name for galaxy identifiers in checkpoint filenames.

  • **kwargs – Forwarded to Fitter.run().

Returns:

One result per galaxy in catalog.

Return type:

list of Posterior

Notes

Sequential fitting (no parallelization yet). For 1000+ galaxies, consider using fit() in a loop with a multiprocessing pool.

Examples

>>> import pandas as pd
>>> cat = pd.read_csv("catalog.csv")
>>> results = model.fit_batch(
...     cat,
...     flux_cols=["f_u", "f_g", "f_r", "f_i", "f_z"],
...     err_cols=["e_u", "e_g", "e_r", "e_i", "e_z"],
...     redshift_col="z",
... )
fit_population(observations_list: list, method: str = 'vi', population_prior: dict | None = None, **kwargs)

Fit a population of galaxies with shared PSD hyperparameters.

Parameters:
  • observations_list (list) – Each element is a (flux, noise) tuple or dict with flux_obs/noise keys.

  • method (str) – Hierarchical inference method. Default "vi".

  • population_prior (dict or None) – Hyperpriors on shared PSD parameters.

  • **kwargs – Forwarded to PopulationFitter.run().

Returns:

Hierarchical inference results with population-level and per-galaxy posteriors.

Return type:

PopulationPosterior

Notes

Enables population-level constraints on shared PSD hyperparameters (e.g., shared burst timescale across a sample). All galaxies must use the same model configuration.

Examples

>>> obs_list = [(flux1, noise1), (flux2, noise2), ...]
>>> result = model.fit_population(obs_list, method="vi")
classmethod from_config(ssp, sfh=Ellipsis, dust=Ellipsis, nebular=Ellipsis, agn=Ellipsis, redshift=Ellipsis, filters: list[str] | None = None, wave_obs=None, priors: dict | None = None, **model_kwargs) → SEDModel

Build an SEDModel from grouped configuration arguments.

Reduces boilerplate for the common case: instead of constructing Parameters, SSPData, Observation, and SEDModel separately, provide a single grouped config and receive a fully configured SEDModel.

Parameters:
  • ssp (str or SSPData) – Path to SSP HDF5 file, or a pre-loaded SSPData instance.

  • sfh (str) – SFH family name, e.g. "tsnorm", "dpl", "dpl+field".

  • dust (str) – Dust attenuation law. "charlot_fall" (default), "calzetti", etc.

  • nebular (str or None) – Nebular emission backend. "baked_in", "cloudy", "cue", or None.

  • agn (str or None) – AGN model. None (disabled) or any AGN model name.

  • redshift (float or str) – Fixed redshift (float), or "free" to add a free redshift parameter.

  • filters (list of str, optional) – Filter names for photometry, e.g. ["sdss_u", "sdss_g", "sdss_r"].

  • wave_obs (array, optional) – Observed-frame wavelength array for spectroscopy.

  • priors (dict, optional) – Parameter priors. Keys may be short names ("log_peak_sfr"), universal short names ("logzsol"), or full prefixed names. Short names are expanded automatically.

  • **model_kwargs – Forwarded to SEDModel.__init__().

Returns:

Fully initialized model ready for prediction or fitting.

Return type:

SEDModel

Notes

Ellipsis (...) placeholders in optional parameters map to defaults from defaults.toml. For example, dust=... uses the default dust attenuation law.
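Using Ellipsis rather than None as the "not specified" marker lets None keep its meaning of "disabled" for nebular and agn. A sketch of the resolution step, with a hypothetical DEFAULTS dict standing in for defaults.toml; apart from "charlot_fall" (the documented dust default), its values are illustrative assumptions:

```python
# Hypothetical stand-in for defaults.toml; values other than dust are illustrative.
DEFAULTS = {"sfh": "tsnorm", "dust": "charlot_fall", "nebular": "baked_in", "agn": None}


def resolve_config(**groups):
    """Replace Ellipsis placeholders with defaults; keep explicit values, including None."""
    resolved = {}
    for name, value in groups.items():
        # `is ...` distinguishes "not given" from an explicit None ("disabled").
        resolved[name] = DEFAULTS[name] if value is ... else value
    return resolved
```

For example, resolve_config(sfh="dpl", dust=..., nebular=None, agn=...) keeps the explicit SFH, fills in the default dust law, and leaves nebular emission disabled.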

Examples

>>> model = tengri.SEDModel.from_config(
...     ssp="data/ssp.h5",
...     sfh="dense_basis",
...     filters=["sdss_u", "sdss_g", "sdss_r"],
...     redshift=0.1,
...     priors=dict(
...         log_total_mass=tengri.Uniform(8, 12),
...         log_sfr_inst=tengri.Uniform(-2, 3),
...         logzsol=tengri.Uniform(-2, 0.2),
...     ),
... )
mock(params, snr=20.0, key=None)

Generate a mock photometric observation with noise.

Parameters:
  • params (dict) – Parameter values.

  • snr (float) – Signal-to-noise ratio. Default 20.0.

  • key (PRNGKey, optional) – Random key for noise. If None, the mock is noiseless.

Returns:

Mock photometric observation.

Return type:

MockData

Notes

Requires model to have filters configured (filters= or observation= in constructor).

Examples

>>> key = jax.random.PRNGKey(0)
>>> mock = model.mock(params, snr=15.0, key=key)
>>> print(mock.flux.shape)  # (n_filters,)
mock_batch(params_batch, snr=20.0, key=None)

Generate batch of mock photometric observations.

Parameters:
  • params_batch (dict of arrays) – Each value has leading batch dimension.

  • snr (float) – Signal-to-noise ratio. Default 20.0.

  • key (PRNGKey, optional) – Random key for noise. If None, the mocks are noiseless.

Returns:

Mock observations with shape (N, n_filters).

Return type:

MockData

Notes

Uses jax.vmap() over mock() for vectorized generation.
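Vectorizing a noisy mock needs one independent PRNG key per galaxy; reusing a single key would give every galaxy the same noise realization. A sketch of the pattern with a toy mock_one (hypothetical, a linear stand-in for the forward model); how mock_batch() splits keys internally is an assumption here:

```python
import jax
import jax.numpy as jnp


def mock_one(params, key, snr=20.0):
    """Toy single-galaxy mock: hypothetical linear fluxes plus Gaussian noise."""
    flux = params["amplitude"] * jnp.array([1.0, 2.0, 3.0])  # (n_filters,)
    sigma = flux / snr  # documented noise model: sigma = flux / snr
    return flux + sigma * jax.random.normal(key, flux.shape)


def mock_batch_sketch(params_batch, key, snr=20.0):
    """vmap over a dict of arrays whose values share a leading batch axis."""
    n = next(iter(params_batch.values())).shape[0]
    keys = jax.random.split(key, n)  # one independent key per galaxy
    return jax.vmap(lambda p, k: mock_one(p, k, snr))(params_batch, keys)
```

jax.vmap maps over the leading axis of every leaf in the params dict and of the key array together, so the output has shape (N, n_filters).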

Examples

>>> import jax
>>> key = jax.random.PRNGKey(0)
>>> # Use up to 1000 posterior draws as the batch (leading axis = galaxy)
>>> params_batch = {k: v[:1000] for k, v in posterior.samples.items()}
>>> mocks = model.mock_batch(params_batch, snr=15.0, key=key)
mock_spectrum(params, wave_obs, snr=30.0, key=None)

Generate mock spectroscopic observation with noise.

Parameters:
  • params (dict) – Parameter values.

  • wave_obs (array) – Observed wavelength grid [Angstrom].

  • snr (float) – Signal-to-noise ratio per pixel. Default 30.0.

  • key (PRNGKey, optional) – Random key for noise. If None, the mock is noiseless.

Returns:

Mock spectroscopic observation.

Return type:

MockData

Notes

Noise is drawn from a Gaussian with per-pixel standard deviation flux/snr.

Examples

>>> wave_obs = np.linspace(4000, 5500, 1000)
>>> mock = model.mock_spectrum(params, wave_obs, snr=10.0, key=key)
>>> print(mock.flux.shape)  # (1000,)
plot_sfh_posterior(posterior, true_params=None, ax=None, n_draws=50, color='C0', label='Posterior')

Plot posterior SFH with percentile fill and sample lines.

Parameters:
  • posterior (Posterior) – Inference results with samples (if available) or params.

  • true_params (dict, optional) – True parameter values (if known) to overlay on plot.

  • ax (matplotlib.axes.Axes, optional) – Axes object to plot on. If None, creates new figure.

  • n_draws (int) – Number of posterior samples to show as thin lines. Default 50.

  • color (str) – Color for posterior lines. Default “C0” (the first color in the active Matplotlib color cycle).

  • label (str) – Label for posterior. Default “Posterior”.

Returns:

ax – The matplotlib Axes object with the plot.

Return type:

matplotlib.axes.Axes

Notes

Shows the 16th to 84th percentile range as a filled region, with individual posterior draws as thin, light curves. If true_params is provided, the true SFH is overlaid in black, with a dashed line for the smooth (parametric) component.
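The filled band is a pointwise percentile over posterior SFH draws. A minimal NumPy sketch of that computation; sfh_percentile_band is a hypothetical helper, and it assumes the draws are already evaluated on a common time grid:

```python
import numpy as np


def sfh_percentile_band(sfh_draws, lo=16.0, hi=84.0):
    """Pointwise percentile band over posterior SFH draws.

    sfh_draws: array of shape (n_draws, n_time), one SFR(t) curve per draw.
    Returns (lower, median, upper), each of shape (n_time,).
    """
    draws = np.asarray(sfh_draws, dtype=float)
    # Percentiles are taken independently at each time step (axis=0).
    lower, median, upper = np.percentile(draws, [lo, 50.0, hi], axis=0)
    return lower, median, upper
```

With a time grid t and a Matplotlib axes ax, the band would be drawn with ax.fill_between(t, lower, upper, alpha=0.3) and the sample lines with ax.plot(t, draws[:50].T, lw=0.5, alpha=0.2). Because the band is pointwise, its edges are not themselves SFHs drawn from the posterior.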

Examples

>>> result = model.fit(flux, noise)
>>> ax = model.plot_sfh_posterior(result)
>>> ax.set_yscale("log")
precompute_spectroscopy(wave_obs)

Pre-interpolate SSP templates to observed wavelength grid.

Call this before spectroscopic fitting for a ~10-20× speedup. Requires fixed redshift.

Parameters:

wave_obs (array, shape (n_pix,)) – Observed wavelength grid (Angstrom).

Returns:

The same model instance, with _state.precomputed.spectroscopy and _state.wave_obs updated immutably via dataclasses.replace. It is returned for chaining, so both model.precompute_spectroscopy(wave); model.fit(flux, noise) and result = model.precompute_spectroscopy(wave).fit(flux, noise) work.

Return type:

self

Notes

State internals (_precomputed, _state) are frozen dataclasses; this method swaps them via dataclasses.replace rather than mutating their fields. The model object itself acts as the controller and is not replaced: cached compiled artefacts are dropped via clear_model_cache(self) and lazily rebuilt on next use.

Caches precomputed SSP spectra at the fixed redshift, enabling ~10-20× speedup for repeated spectrum predictions. Requires fixed redshift (raises ValueError otherwise).
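The update pattern (frozen state swapped with dataclasses.replace, compiled caches dropped, self returned) can be sketched with toy classes. State, Precomputed, and ToyModel are hypothetical stand-ins rather than tengri's internal types, and the "grid" is a placeholder for real template resampling:

```python
import dataclasses
from typing import Optional, Tuple


@dataclasses.dataclass(frozen=True)
class Precomputed:
    spectroscopy: Optional[Tuple[float, ...]] = None  # placeholder for a resampled SSP grid


@dataclasses.dataclass(frozen=True)
class State:
    precomputed: Precomputed = dataclasses.field(default_factory=Precomputed)
    wave_obs: Optional[Tuple[float, ...]] = None


class ToyModel:
    """Hypothetical controller: frozen state plus a mutable cache of compiled artefacts."""

    def __init__(self):
        self._state = State()
        self._cache = {}

    def precompute_spectroscopy(self, wave_obs):
        wave = tuple(float(w) for w in wave_obs)
        grid = tuple(2.0 * w for w in wave)  # stand-in for SSP templates resampled to wave_obs
        precomputed = dataclasses.replace(self._state.precomputed, spectroscopy=grid)
        # Build a new State; the old one stays valid for anything still holding it.
        self._state = dataclasses.replace(self._state, precomputed=precomputed, wave_obs=wave)
        self._cache.clear()  # analogue of clear_model_cache(self); rebuilt lazily on next use
        return self  # enables chaining
```

Returning self rather than a new model keeps existing references to the model valid, while the immutable state objects guarantee that nothing captured by an earlier compiled kernel changes underneath it.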

Examples

>>> wave_obs = np.linspace(3500, 7000, 2000)
>>> model.precompute_spectroscopy(wave_obs)
>>> flux = model.predict_spectrum(params)
precompute_ztable(z_grid=None, z_min=0.001, z_max=3.0, n_z=100)

Pre-compute SSP photometry on a redshift grid for free-z fitting.

At inference time, the precomputed table is interpolated to the current z — same speedup as fixed-z precomputation, but z is free. Follows the DSPS precompute_ssp_obsmags_on_z_table approach.

Parameters:
  • z_grid (array, optional) – Custom redshift grid. If None, uses linspace(z_min, z_max, n_z).

  • z_min (float) – Minimum redshift (default 0.001).

  • z_max (float) – Maximum redshift (default 3.0).

  • n_z (int) – Number of grid points (default 100). More points give more accurate interpolation; 100 gives <0.01% interpolation error for smooth filter transmission curves.

Returns:

The same model instance, with _state.precomputed.photometry_ztable updated immutably via dataclasses.replace. Returned for chaining; cache cleared via clear_model_cache(self).

Return type:

self

Notes

Enables fast photometry prediction with free redshift (no fixed z). Interpolates the precomputed SSP×filter grid to the current z at inference time, achieving a speedup similar to fixed-z precomputation.
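At inference time the lookup reduces to interpolation along the redshift axis, which stays differentiable in z. A sketch, assuming a table whose leading axis is the redshift grid and linear interpolation (the interpolation scheme tengri actually uses is not stated here; interp_ztable is a hypothetical name):

```python
import jax
import jax.numpy as jnp


def interp_ztable(z, z_grid, table):
    """Linearly interpolate a precomputed table along its leading redshift axis.

    z_grid: (n_z,) increasing redshifts; table: (n_z, ...) e.g. (n_z, n_met, n_age, n_filters).
    """
    flat = table.reshape(table.shape[0], -1)  # (n_z, n_rest)
    # Interpolate every trailing entry at the same z.
    out = jax.vmap(lambda col: jnp.interp(z, z_grid, col), in_axes=1)(flat)
    return out.reshape(table.shape[1:])
```

Because jnp.interp is piecewise linear, jax.grad with respect to z returns the local slope of each table entry, so redshift can be sampled as a free parameter with gradient-based inference.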

Examples

>>> model.precompute_ztable(z_min=0.01, z_max=4.0, n_z=200)
>>> flux = model.predict_photometry(params)  # z now free
predict(params)

Create a lazy prediction object for all derived physical quantities.

Returns a Prediction object that computes and caches derived quantities on first access. This is the recommended API for interactive exploration of a single galaxy’s properties, trading speed for convenience.

For batch computation over posterior chains or mock catalogs, use the JIT-compatible methods predict_sfh_quantities(), predict_sed_quantities(), or predict_line_luminosities() with jax.vmap() instead (up to 1000× faster for large batches).

Parameters:

params (dict) – Parameter values using public parameter names.

Returns:

Lazy caching wrapper with property groups:

  • .sfh : SFH-derived quantities (stellar mass, SFR, age, metallicity)

  • .sed : SED-derived quantities (luminosities, colors, indices)

  • .lines : Emission line properties (luminosities, fluxes, ratios)

  • .radio : Radio SED properties (if radio=True)

  • .xray : X-ray SED properties (if xray=True)

  • .ionizing : Ionizing photon budget properties

Return type:

Prediction

Notes

Not JIT-compatible: Uses Python-side caching and object attribute access. Useful for interactive exploration, not for inference loops. For inference, use predict_sfh_quantities(), predict_sed_quantities(), etc. with jax.vmap().

Lazy evaluation: Quantities are computed only when accessed. Repeated access to the same property reuses cached results. This is transparent to the user.

NaN handling: Some quantities (e.g., stellar_mass_surviving, l_dust_absorbed) evaluate to NaN when the data or parameters they require are unavailable (e.g., no mass-remaining table loaded, or dust_model='none'). The Prediction object reports such quantities as None rather than NaN.
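The lazy, cached, NaN-to-None behaviour described above can be sketched with functools.cached_property. LazySFH is a hypothetical toy, not the Prediction class; it only illustrates compute-once-on-first-access and the NaN mapping:

```python
import math
from functools import cached_property


class LazySFH:
    """Toy lazy wrapper: compute on first access, cache, report NaN as None."""

    def __init__(self, compute):
        self._compute = compute  # returns a dict of floats, possibly containing NaN
        self.n_calls = 0

    @cached_property
    def _values(self):
        self.n_calls += 1  # counts how often the expensive computation runs
        return self._compute()

    def _get(self, name):
        value = self._values[name]
        return None if math.isnan(value) else value

    @property
    def stellar_mass(self):
        return self._get("stellar_mass")

    @property
    def stellar_mass_surviving(self):
        return self._get("stellar_mass_surviving")
```

The Python-side cache and the branching on NaN are exactly what make this style unsuitable for jax.jit(), which is why batch work goes through predict_sfh_quantities() instead.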

Examples

Single-galaxy exploration (lazy, on-demand):

>>> pred = model.predict(params)
>>> pred.sfh.stellar_mass  # triggers SFH computation, caches result
Array(1.23e10, dtype=float64)
>>> pred.sfh.mass_weighted_age_gyr  # reuses cached SFH
Array(2.34, dtype=float64)
>>> pred.sed.l_bol  # triggers SED computation
Array(2.5e10, dtype=float64)
>>> pred.sed.uv_slope_beta  # reuses cached SED
Array(-1.8, dtype=float64)
>>> pred.lines.halpha  # triggers nebular computation
Array(4.23e-15, dtype=float64)

Batch computation (JIT-compatible, faster for large N):

>>> import jax
>>> params_batch = spec.sample(jax.random.PRNGKey(0), n=10000)
>>> sfh_fn = jax.vmap(model.predict_sfh_quantities)
>>> sfh_batch = sfh_fn(params_batch)
>>> sfh_batch.stellar_mass  # shape (10000,)
>>> sfh_batch.stellar_mass.mean()

See also

predict_sfh_quantities

JIT-compatible SFH quantities for batch.

predict_sed_quantities

JIT-compatible SED quantities for batch.

predict_line_luminosities

JIT-compatible emission lines for batch.

predict_rest_sed

Full rest-frame SED for custom analysis.

predict_derived(params)

Compute derived physical quantities as a flat dict.

Convenience wrapper around predict() that extracts the key SFH-derived scalars into a plain dict. Use predict() for lazy on-demand access to all quantities, or predict_sfh_quantities() for JIT-compatible batch computation.

Parameters:

params (dict) – Parameter values.

Returns:

dict with keys:

  • "stellar_mass": total mass formed [M_sun]

  • "stellar_mass_surviving": mass in living stars + remnants [M_sun], or None if the mass-remaining table is not loaded

  • "sfr_100myr": SFR averaged over the last 100 Myr [M_sun/yr]

  • "sfr_10myr": SFR averaged over the last 10 Myr [M_sun/yr]

  • "ssfr": specific SFR [yr^-1], using the surviving mass if available, else the formed mass

Return type:

dict

Notes

JIT-compatible: no — wraps predict().

Convenience wrapper around the lazy predict() object. For batch operations, use predict_sfh_quantities() directly with jax.vmap().

predict_emission_lines_via_orchestrator(params)

Emission-line luminosities computed through the SEDComponent orchestrator path (Phase II-2.6).

Returns:

11 standard survey-diagnostic lines (lya, civ_1549, oii, hbeta, oiii_4959/5007, nii_6548/6584, halpha, sii_6717/6731). Returns all-NaN when the active nebular backend does not publish a discrete line catalogue (BakedIn, shock).

Return type:

EmissionLines

predict_hbeta(params: dict) → float

Predict Hβ luminosity for use with CLOUDY-informed emission line priors.

Required by marginalize_emission_lines_cloudy() as the l_hbeta argument, which scales CLOUDY’s ratio-relative-to-Hβ priors to physical units.

Hβ luminosity is computed via the Case B recombination approximation (Leitherer et al. 1999):

\[L_{H\beta} \approx 5.22 \times 10^7 \times \text{SFR}_{10} \; [L_\odot]\]

where \(\text{SFR}_{10}\) is the SFR [M_sun/yr] averaged over the last 10 Myr (the timescale relevant for ionizing photons). The coefficient follows from Q_H ≈ 4.2 × 10⁵³ × SFR_10 [photons/s] and L_Hβ = 4.76 × 10⁻¹³ × Q_H erg/s, converted to L_sun.
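The 5.22 × 10⁷ coefficient can be checked by chaining the two quoted relations and dividing by the solar luminosity. A worked sketch, assuming the IAU 2015 nominal L_sun = 3.828 × 10³³ erg/s (the value quoted for predict_luminosity()); hbeta_lsun is a hypothetical helper:

```python
# Constants quoted in the text above.
Q_H_PER_SFR = 4.2e53       # ionizing photons/s per (M_sun/yr)
L_HBETA_PER_Q = 4.76e-13   # erg/s of H-beta per ionizing photon/s (Case B)
L_SUN_ERG_S = 3.828e33     # IAU 2015 nominal solar luminosity [erg/s]

# 4.2e53 * 4.76e-13 = 1.9992e41 erg/s per (M_sun/yr); divided by L_sun gives ~5.22e7.
COEFF = Q_H_PER_SFR * L_HBETA_PER_Q / L_SUN_ERG_S


def hbeta_lsun(sfr_10myr):
    """H-beta luminosity [L_sun] from the 10-Myr-averaged SFR [M_sun/yr]."""
    return COEFF * sfr_10myr
```

So a galaxy forming 2 M_sun/yr over the last 10 Myr would be assigned L_Hβ ≈ 1.04 × 10⁸ L_sun.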

Parameters:

params (dict) – Model parameters (from spec.sample() or a Posterior).

Returns:

Hβ luminosity [Lsun].

Return type:

float

Examples

>>> l_hbeta = model.predict_hbeta(params)
>>> ln_L = marginalize_emission_lines_cloudy(
...     residual,
...     noise,
...     A,
...     log_z=params["met_logzsol"],
...     neb_logU=-3.0,
...     l_hbeta=l_hbeta,
... )

Notes

JIT-compatible: no — wraps predict_sfh_quantities().

Uses Case B recombination coefficients (Leitherer et al. 1999 [1]). If the SFH computation fails (e.g., invalid params), returns a fallback value of 1 L_sun.

See also

predict_sfh_quantities

JIT-compatible SFH quantities including sfr_10myr.


predict_ionizing_quantities_via_orchestrator(params)

Ionizing-photon quantities computed through the SEDComponent orchestrator path (Phase II-2.6).

Returns:

q_h, xi_ion.

Return type:

IonizingQuantities

predict_line_fluxes(params, target_wavelengths=None, tolerance_aa=5.0)

Predict observed emission line fluxes.

Calls the nebular backend to compute line luminosities, selects target lines by wavelength matching, and converts from luminosity (Lsun) to observed flux (erg/s/cm^2).

Parameters:
  • params (dict) – Parameter values (public names).

  • target_wavelengths (array, shape (n_target,), optional) – Rest-frame vacuum wavelengths (Angstrom) of lines to predict. Each wavelength is matched to the nearest backend line. If None, returns all lines from the nebular backend.

  • tolerance_aa (float or None, default 5.0) – Maximum allowed wavelength offset [Angstrom] between a requested target and the matched catalogue line. Raises ValueError on any miss, listing the offending targets. Pass None to disable the check (restoring the legacy behaviour of always taking the nearest line).
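Nearest-line matching with a tolerance can be sketched in a few lines of NumPy. match_lines is a hypothetical helper, and the catalogue wavelengths in the usage example are illustrative vacuum values, not tengri's line list:

```python
import numpy as np


def match_lines(target_wavelengths, catalogue_wavelengths, tolerance_aa=5.0):
    """Index of the nearest catalogue line for each target wavelength [Å]."""
    targets = np.atleast_1d(np.asarray(target_wavelengths, dtype=float))
    catalogue = np.asarray(catalogue_wavelengths, dtype=float)
    delta = np.abs(targets[:, None] - catalogue[None, :])  # (n_target, n_catalogue)
    idx = np.argmin(delta, axis=1)
    if tolerance_aa is not None:
        # Report every target whose nearest line is still too far away.
        miss = delta[np.arange(len(targets)), idx] > tolerance_aa
        if np.any(miss):
            raise ValueError(
                f"no catalogue line within {tolerance_aa} Å of {targets[miss].tolist()}"
            )
    return idx
```

With a catalogue of [4862.7, 5008.2, 6564.6], targets [6564.0, 4863.0] match indices [2, 0], while a target at 6000.0 raises ValueError unless tolerance_aa=None.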

Returns:

fluxes – Observed line fluxes in erg/s/cm^2.

Return type:

array, shape (n_target,) or (n_all_lines,)

Raises:

ValueError – If no nebular backend is configured, or if a target wavelength has no catalogue line within tolerance_aa.

Notes

JIT-compatible: no — delegates to nebular backend.

Observed flux is calculated from luminosity via:

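\[F = \frac{L \, L_{\odot}}{4\pi d_L^2}\]

where \(L\) is the line luminosity in solar units, \(L_{\odot} = 3.828 \times 10^{33}\) erg/s converts it to erg/s, and \(d_L\) is the luminosity distance in cm.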
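A worked version of the flux conversion, assuming d_L is supplied directly in Mpc rather than computed from a cosmology (line_flux is a hypothetical helper):

```python
import math

L_SUN_ERG_S = 3.828e33           # IAU 2015 nominal solar luminosity [erg/s]
MPC_CM = 3.0856775814913673e24   # 1 Mpc in cm


def line_flux(l_lsun, d_l_mpc):
    """Observed line flux [erg/s/cm^2] from a luminosity [L_sun] and d_L [Mpc]."""
    d_l_cm = d_l_mpc * MPC_CM
    return l_lsun * L_SUN_ERG_S / (4.0 * math.pi * d_l_cm**2)
```

For example, a 10⁸ L_sun line at d_L = 100 Mpc gives F ≈ 3.2 × 10⁻¹³ erg/s/cm²; doubling d_L quarters the flux.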

predict_luminosity(params)

Compute rest-frame luminosity SED in solar units.

Parameters:

params (dict) – Parameter values using public parameter names.

Returns:

Rest-frame luminosity [L_sun/Hz].

Return type:

array, shape (n_wave,)

Notes

JIT-compatible: no — wraps predict_rest_sed().

Divides rest-frame SED by \(L_{\odot} = 3.828 \times 10^{33}\) erg/s (IAU 2015 definition).

predict_magnitudes(params)

Compute observed AB magnitudes through all filters.

Parameters:

params (dict) – Parameter values using public parameter names.

Returns:

magnitudes – Observed AB magnitudes [mag].

Return type:

ndarray, shape (n_filters,)

Notes

JIT-compatible: yes (via predict_photometry or predict_luminosity).

Uses dsps.calc_obs_mag() when available (cosmology-aware), falls back to conversion from photometric flux otherwise.
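The fallback conversion from flux density to AB magnitude follows the standard AB definition, m_AB = -2.5 log10(f_ν) - 48.60 with f_ν in erg/s/cm²/Hz. A sketch of that definition (ab_mag is a hypothetical helper, not the function tengri calls):

```python
import numpy as np


def ab_mag(f_nu):
    """AB magnitude from flux density f_nu [erg/s/cm^2/Hz]."""
    return -2.5 * np.log10(np.asarray(f_nu, dtype=float)) - 48.60
```

The AB zero point corresponds to 3631 Jy = 3.631 × 10⁻²⁰ erg/s/cm²/Hz, so ab_mag(3.631e-20) is about 0, and a source 100 times fainter is 5 magnitudes fainter.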

predict_obs_sed(params, wave=None)

Compute observed-frame SED (redshifted + IGM + DLA transmission).

Evaluates the rest-frame SED, redshifts to observed frame (wavelength × (1+z)), and applies IGM and DLA absorption where configured. At z=0, identical to predict_rest_sed().

Parameters:
  • params (dict) – Parameter values using public parameter names.

  • wave (array, optional) – Custom rest-frame wavelength grid [Angstrom] before redshifting. If None, uses model default.

Returns:

NamedTuple with:

  • wavelength : array, shape (n_wave,). Observed-frame wavelength [Ångstrom]

  • sed : array, shape (n_wave,). Observed-frame spectral luminosity density [erg/s/Hz]

Return type:

SEDResult

Notes

JIT-compatible: no — delegates to predict_rest_sed().

IGM absorption: Applies transmission via \(T_{\mathrm{IGM}}(\lambda_{\mathrm{obs}}, z)\) when igm=True in spec. Uses Inoue+2014 [1] mean IGM with optional extensions for:

  • Reionization epoch: CGM damping wing (Asada+2025 [2])

  • Patchy reionization: parameterized neutral fraction (Mason+2018 [3])

CRITICAL GOTCHA: IGM transmission must be evaluated at observed-frame wavelengths. The wavelength grid of this SED is already redshifted to the observed frame, so igm_transmission(wave_obs, z) receives the correct input.
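Why the frame matters can be shown with a toy transmission that attenuates everything blueward of observed-frame Lyman-α. toy_igm_transmission is a hypothetical step model, not Inoue+2014; only the placement of the break is the point:

```python
import numpy as np

LYA_REST_AA = 1215.67  # Lyman-alpha rest wavelength [Å]


def toy_igm_transmission(wave_obs, z, tau=5.0):
    """Hypothetical step model: exp(-tau) blueward of Ly-alpha in the OBSERVED frame."""
    wave_obs = np.asarray(wave_obs, dtype=float)
    return np.where(wave_obs < LYA_REST_AA * (1.0 + z), np.exp(-tau), 1.0)


z = 6.0
wave_rest = np.array([1100.0, 1300.0, 1500.0])
t_right = toy_igm_transmission(wave_rest * (1.0 + z), z)  # only 1100 Å (rest) absorbed
t_wrong = toy_igm_transmission(wave_rest, z)              # rest-frame input: all absorbed
```

At z = 6 the break sits at about 8510 Å observed; passing rest-frame wavelengths places every pixel blueward of it and wrongly absorbs the whole UV continuum.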

DLA absorption: Applies Lyman-series damping wing when dla=True. Parameterized by neutral column density log₁₀(N_HI) and temperature. See dla_transmission_obs().

Physical units:

  • Wavelength: observed-frame Ångstrom (redshifted)

  • SED: erg/s/Hz (same as rest-frame), but sampled at redshifted wavelengths and rescaled by the \((1+z)\) factor from cosmological redshift

Examples

>>> sed_obs = model.predict_obs_sed(params)
>>> # IGM and redshift already applied
>>> print(f"z={params['redshift']}: wavelength {sed_obs.wavelength[0]:.0f} Å")

See also

predict_rest_sed

Rest-frame SED (before redshift/IGM).

predict_photometry

Filter-integrated observed flux (uses this internally).


predict_photometry(params, mode='auto', approx=None)

Compute observed photometric flux densities through all filters.

Convolves the SED (redshifted and IGM-absorbed) through filter transmission curves, returning observed flux densities in the AB system. Supports three prediction modes for the speed/accuracy tradeoff: compositional (exact, XLA-fused), hybrid (precomputed stellar + exact non-stellar), and exact (full pipeline, slowest).

Parameters:
  • params (dict) – Parameter values using public parameter names (e.g., sfh_tsnorm_log_peak_sfr, met_logzsol, redshift). See Parameters for canonical names.

  • mode (str, optional) –

    Prediction strategy. Default "auto" selects fastest available.

    • "auto" — cascade through available: compositional → hybrid → exact

    • "compositional" — full-resolution JIT SED kernel (bit-exact, fastest). All components evaluated at full wavelength, integrated through filters in single XLA-fused graph. Preferred when available.

    • "hybrid" — precomputed SSP×filter photometry (stellar, ~0.4% error) + exact non-stellar (emission, AGN, dust) at full wavelength, integrated through filters. Fallback when compositional unavailable (e.g., variable-redshift, evolving metallicity, tabulated SFH).

    • "exact" — raw forward pipeline, no kernel JIT, no precomputation. Reference accuracy, slowest (~5–10× slower than compositional).

  • approx (bool, optional) – Maps True → mode="auto" and False → mode="exact". Prefer passing mode= directly.

Returns:

flux_density – Observed-frame flux densities in erg/s/cm²/Hz (AB system), including the luminosity-distance and (1+z) factors.

Return type:

array, shape (n_filters,)

Raises:

ValueError – If no filters configured in the model (pass filters or observation= to constructor).

Notes

JIT-compatible: yes — compositional and hybrid modes are JIT’d at initialization. Exact mode is not JIT’d. All modes are safe inside jax.grad() for parameter gradients.

Approximate accuracy: Compositional and hybrid modes agree with exact mode to within 0.1%–0.4%, depending on the mode and on which approximations are enabled (see CLAUDE.md for mode-specific tolerances). Differences are driven by:

  • Compositional: None (bit-exact vs. exact)

  • Hybrid: ~0.4% stellar photometry (Zacharegkas+2025 [1])

  • Approximations enabled via approx: see SEDModel for individual component tolerances

Filter wavelengths: Filter transmission curves loaded via load_filter_set() or Photometry are assumed to be defined in the observed frame. The model redshifts the rest-frame SED wavelengths by \((1+z)\) before filter integration.
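Filter integration of a redshifted SED can be sketched in NumPy. This uses the common photon-counting AB convention, ⟨f_ν⟩ = ∫ f_ν T λ⁻¹ dλ / ∫ T λ⁻¹ dλ, which is an assumption about tengri's weighting; filter_flux is hypothetical, and f_nu is taken as already scaled to observed flux density (distance and (1+z) factors applied elsewhere):

```python
import numpy as np


def _trapz(y, x):
    # Trapezoidal rule written out to avoid NumPy-version differences (trapz vs trapezoid).
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def filter_flux(wave_rest, f_nu, z, filt_wave, filt_trans):
    """Band-averaged f_nu through an observed-frame filter (photon-counting AB weighting)."""
    wave_obs = np.asarray(wave_rest, dtype=float) * (1.0 + z)  # redshift the SED grid
    filt_wave = np.asarray(filt_wave, dtype=float)
    filt_trans = np.asarray(filt_trans, dtype=float)
    # Resample the SED onto the filter's (observed-frame) wavelength grid.
    f_on_filter = np.interp(filt_wave, wave_obs, f_nu, left=0.0, right=0.0)
    weight = filt_trans / filt_wave  # T(lambda) / lambda
    return _trapz(f_on_filter * weight, filt_wave) / _trapz(weight, filt_wave)
```

A flat f_ν spectrum returns its own value for any filter, which is a quick sanity check on the normalization.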

See also

predict

Lazy prediction object for all derived quantities.

predict_spectrum

Spectral flux at arbitrary wavelengths.

predict_magnitudes

AB magnitudes (uses photometry internally).

Examples

>>> flux = model.predict_photometry(params)
>>> mags = model.predict_magnitudes(params)
>>> flux_exact = model.predict_photometry(params, mode="exact")


predict_photometry_batch(params_batch)

Compute photometry for a batch of parameter sets via jax.vmap.

Parameters:

params_batch (dict of arrays) – Each value has shape (N, …) with leading batch dimension.

Returns:

Photometric flux for each galaxy.

Return type:

array, shape (N, n_filters)

Notes

JIT-compatible: yes — uses jax.vmap() over predict_photometry().

Examples

>>> import jax.numpy as jnp
>>> # Replicate one parameter dict 100 times along a new leading batch axis
>>> params_batch = {
...     k: jnp.broadcast_to(jnp.asarray(v), (100,) + jnp.shape(v))
...     for k, v in params.items()
... }
>>> flux_batch = model.predict_photometry_batch(params_batch)
>>> flux_batch.shape  # (100, n_filters)

predict_photometry_via_orchestrator(params)

Photometry through the orchestrator path.

Runs the SEDComponent chain on the model’s configuration, then projects the resulting rest-frame SED through every filter in self.observation.photometry. Returns observed-frame flux densities in the AB system.

Parameters:

params (Mapping) – Free-parameter dict (same format as for predict_via_orchestrator()).

Returns:

flux_density – Observed flux densities [erg/s/cm²/Hz].

Return type:

ndarray, shape (n_filters,)

Raises:

ValueError – If no photometric filters are configured on the observation.

Notes

JIT-compatible: yes — uses jax.jit()-friendly tengri.observation.photometry.compute_flux_density() per filter.

Unlike the legacy predict_photometry(), this path goes through the SEDComponent orchestrator (no fused-kernel dispatch). For inference workflows that compile once and run thousands of times, the warm latency is equivalent (~2 ms); for one-shot photometry, the legacy path with its tier-1/tier-2 fast paths is still faster.

predict_radio_quantities_via_orchestrator(params)

Radio quantities computed through the orchestrator path (Phase II-2.6).

Returns:

Fields l_1p4ghz, l_thermal, l_nonthermal, and q_ir. All fields are NaN if the configured chain has no RadioSEDComponent.

Return type:

RadioQuantities

predict_rest_sed(params, wave=None)

Compute rest-frame panchromatic SED luminosity spectrum.

Evaluates all stellar populations, emission (nebular, AGN), and multi-wavelength (radio, X-ray) components in rest-frame coordinates. Returns the total SED integrated across the age distribution set by the SFH and stellar mass parameters.

Parameters:
  • params (dict) – Parameter values using public parameter names.

  • wave (array, optional) – Custom rest-frame wavelength grid [Angstrom]. If None, uses the model’s default: SSP wavelength grid (ssp_data.ssp_wave), or auto-extended grid if radio=True or xray=True in spec.

Returns:

NamedTuple with:

  • wavelength : array, shape (n_wave,). Rest-frame wavelength [Ångstrom]

  • sed : array, shape (n_wave,). Spectral luminosity density [erg/s/Hz]

Return type:

SEDResult

Notes

JIT-compatible: no — computes SED components via the orchestrator path (predict_via_orchestrator()) which is not JIT’d. For JIT-compatible SED access, use predict_sed_quantities() instead.

Physical units:

  • Wavelength: rest-frame Ångstrom (not redshifted)

  • SED: erg/s/Hz (L_ν), normalized to the total stellar mass implied by the SFH

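SED components: The total SED combines the following terms (dust attenuation multiplies the stellar and AGN terms rather than adding to them):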

  • Stellar continuum (CSP from SSP integration)

  • Nebular continuum (if nebular_mode ≠ ‘baked-in’)

  • Nebular emission lines (if neb_* params free)

  • AGN continuum (if agn_model set)

  • Dust attenuation (applied to stellar + AGN)

  • Dust emission (re-radiated IR, if dust_emission_model set)

  • Shock emission (if shock=True)

  • Radio/X-ray (if radio=True or xray=True)

Attenuation: Applied via two-component (birth cloud + diffuse ISM) or single-screen dust law, parameterized by age-dependent optical depth. See components.dust for available laws.
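A two-component (birth cloud + diffuse ISM) screen can be sketched as follows. This is an illustration in the spirit of Charlot & Fall (2000), not tengri's code: the 10 Myr birth-cloud lifetime, the 5500 Å normalization, and the use of one power-law slope for both components are assumptions; -0.7 matches the dust_slope default listed under Parameters.

```python
import math


def screen_transmission(wave_aa, tau_v, slope=-0.7):
    """exp(-tau_V * (lambda / 5500 A)**slope): power-law screen normalized at V band."""
    return math.exp(-tau_v * (wave_aa / 5500.0) ** slope)


def attenuated_luminosity(wave_aa, ages_yr, lum, tau_bc, tau_diff, slope=-0.7, t_bc_yr=1e7):
    """Sum luminosity over age bins at one wavelength with two-component attenuation.

    Populations younger than t_bc_yr sit behind both the birth cloud and the
    diffuse ISM; older populations see only the diffuse ISM.
    """
    t_diff = screen_transmission(wave_aa, tau_diff, slope)
    t_young = t_diff * screen_transmission(wave_aa, tau_bc, slope)
    return sum(l * (t_young if age < t_bc_yr else t_diff) for age, l in zip(ages_yr, lum))


# One young (3 Myr) and one old (1 Gyr) population with equal intrinsic luminosity
l_v = attenuated_luminosity(5500.0, [3e6, 1e9], [1.0, 1.0], tau_bc=1.0, tau_diff=0.5)
l_uv = attenuated_luminosity(1500.0, [3e6, 1e9], [1.0, 1.0], tau_bc=1.0, tau_diff=0.5)
```

Because the slope is negative, the optical depth grows toward the UV, so l_uv comes out smaller than l_v for the same intrinsic luminosities.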

Examples

>>> sed = model.predict_rest_sed(params)
>>> import matplotlib.pyplot as plt
>>> plt.loglog(sed.wavelength, sed.sed)
>>> plt.xlabel("Rest-frame wavelength (Angstrom)")
>>> plt.ylabel("SED (erg/s/Hz)")

See also

predict_obs_sed

Observed-frame SED (redshifted + IGM).

predict_sed_quantities

JIT-compatible SED-derived quantities.

predict_sed_quantities(params)

Compute SED-derived quantities in JIT-compatible form.

Evaluates the full forward model and computes UV slope, spectral indices (D4000, Balmer break), bolometric/IR luminosities, dust attenuation, and luminosity-weighted age/metallicity. Returns a SEDQuantities NamedTuple that is fully JIT-compatible and vmap-ready for batch inference.

Parameters:

params (dict) – Parameter values using public parameter names.

Returns:

NamedTuple with fields:

  • l_bol : float. Bolometric luminosity [L☉]

  • l_tir : float. Total infrared (8–1000 μm) luminosity [L☉]

  • l_dust_absorbed : float. Dust-absorbed luminosity [L☉] (intrinsic − attenuated), or NaN if intrinsic SED unavailable.

  • irx : float. Infrared excess := L_TIR / L_UV(1600 Å). Common probe of dust obscuration (Dale et al. 2001).

  • uv_slope_beta : float. UV slope (power-law index) in f_λ ∝ λ^β for 1200–2600 Å.

  • dn4000 : float. D_n(4000) break ratio: mean flux in 4050–4250 Å divided by mean flux in 3750–3950 Å. Indicator of stellar age.

  • balmer_break : float. Balmer break: flux ratio ~3700 Å / ~4000 Å. Signature of intermediate-age (A-star dominated) stellar populations.

  • m_uv : float. Absolute magnitude at 1500 Å (M_1500, standard reionization-era indicator).

  • fuv_flux : float. Flux at 1500 Å [erg/s/cm²]

  • nuv_flux : float. Flux at 2300 Å [erg/s/cm²]

  • fuv_flux_intrinsic : float. FUV flux, dust-free (intrinsic SED). NaN if unavailable.

  • nuv_flux_intrinsic : float. NUV flux, dust-free. NaN if unavailable.

  • rest_uv_color : float. Rest-frame UV color (f_1500 − f_2300).

  • luminosity_weighted_age_gyr : float. Luminosity-weighted age [Gyr] (∫L_λ age dλ / ∫L_λ dλ).

  • luminosity_weighted_metallicity : float. Luminosity-weighted log₁₀(Z/Z☉) or absolute log₁₀(Z).

Return type:

SEDQuantities

Notes

JIT-compatible: yes — all operations use jnp primitives. Safe inside jax.jit(), jax.vmap(), and jax.grad().

Gradient-safe: yes — all quantities are differentiable w.r.t. SFH, metallicity, and dust parameters.

Spectral indices: Computed directly on the rest-frame SED (not broadband-filtered). All wavelengths defined in rest frame.
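The UV slope and 4000 Å break can be sketched directly on a rest-frame spectrum. This is a hedged illustration, not tengri's estimator: β is taken as the least-squares slope of log f_λ against log λ over 1200–2600 Å, and the break as a ratio of pixel-averaged fluxes in the windows quoted above (tengri may integrate instead, and the f_ν vs. f_λ choice for the break is an assumption). Function names are hypothetical.

```python
import math


def uv_slope_beta(wave_aa, flam, lo=1200.0, hi=2600.0):
    """Least-squares slope of log10 f_lambda vs log10 lambda within [lo, hi]."""
    pts = [(math.log10(w), math.log10(f)) for w, f in zip(wave_aa, flam) if lo <= w <= hi]
    mx = sum(x for x, _ in pts) / len(pts)
    my = sum(y for _, y in pts) / len(pts)
    sxy = sum((x - mx) * (y - my) for x, y in pts)
    sxx = sum((x - mx) ** 2 for x, _ in pts)
    return sxy / sxx


def window_mean(wave_aa, flux, lo, hi):
    """Mean flux of the pixels inside [lo, hi]."""
    vals = [f for w, f in zip(wave_aa, flux) if lo <= w <= hi]
    return sum(vals) / len(vals)


def d4000(wave_aa, fnu):
    """Red/blue break ratio: mean f_nu in 4050-4250 A over 3750-3950 A."""
    return window_mean(wave_aa, fnu, 4050.0, 4250.0) / window_mean(wave_aa, fnu, 3750.0, 3950.0)


# Power law f_lambda ~ lambda^-2 should give beta = -2
wave_uv = [1000.0 + 10.0 * i for i in range(201)]  # 1000-3000 A
beta = uv_slope_beta(wave_uv, [w**-2.0 for w in wave_uv])

# Step spectrum that doubles redward of 4000 A should give a break of 2
wave_opt = [3700.0 + 10.0 * i for i in range(61)]  # 3700-4300 A
dn = d4000(wave_opt, [1.0 if w < 4000.0 else 2.0 for w in wave_opt])
```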

Dust-absorbed luminosity: Defined as L_dust = L_intrinsic − L_attenuated (i.e., the energy re-radiated in the IR). Requires the forward model to track both intrinsic and attenuated SEDs internally. Returns NaN if dust_model="none" or intrinsic SED not available.

Luminosity-weighted quantities: Computed as:

\[\langle Q \rangle_L = \frac{\int L_\lambda(\lambda) Q(\lambda) d\lambda} {\int L_\lambda(\lambda) d\lambda}\]

More sensitive to young, UV-bright populations than mass-weighted age.

Examples

Single galaxy:

>>> sed_q = model.predict_sed_quantities(params)
>>> sed_q.l_bol
Array(2.5e10, dtype=float64)
>>> sed_q.dn4000
Array(1.42, dtype=float64)
>>> sed_q.irx
Array(1.87, dtype=float64)

Batch over posterior samples:

>>> import jax
>>> sed_fn = jax.vmap(model.predict_sed_quantities)
>>> sed_batch = sed_fn(params_batch)
>>> sed_batch.m_uv  # shape (n_samples,)
>>> sed_batch.dn4000.mean()

Computing the IRX–β relation:

>>> sed_q = sed_fn(params_batch)
>>> irx = sed_q.irx
>>> beta = sed_q.uv_slope_beta
>>> # Compare to Meurer et al. (1999) IRX-β calibration

See also

predict

Lazy prediction for single-galaxy exploration.

predict_sfh_quantities

JIT-compatible SFH quantities.

predict_rest_sed

Full rest-frame SED (for custom analysis).

predict_sed_quantities_via_orchestrator(params)

Drop-in replacement for predict_sed_quantities().

Returns:

15-field NamedTuple matching the legacy contract.

Return type:

SEDQuantities

predict_sfh(params, n_linear=1000)

Compute SFH on uniform linear-time grid for visualization.

Evaluates the SFH parameterization at n_linear evenly-spaced points in lookback time, returning both the smooth parametric component (sfr_mean) and the full SFH including GP-field modulation (sfr_full, if stochastic SFH enabled).

Parameters:
  • params (dict) – Parameter values using public parameter names.

  • n_linear (int, optional) – Number of output grid points, evenly spaced in lookback time. Default 1000 (sufficient for smooth visualization).

Returns:

  • "t_gyr" : array, shape (n_linear,). Lookback time [Gyr], from 0 (now) to ~13.8 (Big Bang).

  • "sfr_mean" : array, shape (n_linear,). Parametric mean SFR [M☉/yr] (no GP modulation).

  • "sfr_full" : array, shape (n_linear,). Full SFH including GP field [M☉/yr]. Identical to sfr_mean if stochastic SFH not enabled.

Return type:

dict (keys as listed under Returns)

Notes

JIT-compatible: no — uses Python-side interpolation. For JIT-compatible SFH evaluation, use predict_sfh_quantities() to get integrated quantities (stellar mass, age, etc.).

Time grid: Output is on a uniform linear-time (lookback) grid rather than the internal log-age grid, which makes it directly suitable for plotting.

SFH mean vs. full: When correlated-field (stochastic) SFH is enabled, sfr_mean shows the smooth parametric trend (e.g., exponential decline), while sfr_full adds GP modulation for realistic burstiness. If parametric-only SFH is used, they are identical.

Physical units: Output SFR is in M☉/yr. Lookback time is in Gyr (cosmic time before today).

Examples

>>> sfh = model.predict_sfh(params)
>>> print(sfh.keys())
dict_keys(['t_gyr', 'sfr_mean', 'sfr_full'])
>>> import matplotlib.pyplot as plt
>>> plt.plot(sfh["t_gyr"], sfh["sfr_mean"], label="Smooth")
>>> # sfr_full is always present; it equals sfr_mean when no GP field is enabled
>>> plt.plot(sfh["t_gyr"], sfh["sfr_full"], alpha=0.5, label="With bursts")
>>> plt.xlabel("Lookback time (Gyr)")

See also

predict_sfh_quantities

Integrated SFH quantities (JIT-compatible).

predict

Lazy access to SFH and all derived quantities.

predict_sfh_quantities(params)

Compute SFH-derived quantities in JIT-compatible form.

Integrates the SFH to compute stellar mass, recent SFR, specific SFR, and mass-weighted age/metallicity. Returns a SFHQuantities NamedTuple that is fully JIT-compatible and vmap-ready for batch inference over posterior chains or mock catalogs.

Parameters:

params (dict) – Parameter values using public parameter names.

Returns:

NamedTuple with fields:

  • stellar_mass : float. Total stellar mass formed [M☉]

  • stellar_mass_surviving : float. Mass in living stars + remnants [M☉], or NaN if SSP mass-remaining tables not loaded.

  • sfr_100myr : float. SFR time-averaged over last 100 Myr [M☉/yr]

  • sfr_10myr : float. SFR time-averaged over last 10 Myr [M☉/yr]

  • ssfr : float. Specific SFR (SFR/M_surv or SFR/M_formed) [yr⁻¹]

  • mass_weighted_age_gyr : float. Mass-weighted age [Gyr]

  • mass_weighted_metallicity : float. Mass-weighted log₁₀(Z/Z☉) or absolute log₁₀(Z) depending on metallicity mode

Return type:

SFHQuantities

Notes

JIT-compatible: yes — all operations use jnp primitives. Safe inside jax.jit(), jax.vmap(), and jax.grad().

Gradient-safe: yes — all quantities are differentiable w.r.t. SFH and metallicity parameters.

Surviving mass: Requires SSP grid with ssp_mass_remaining (e.g., FSPS grids). If unavailable, returns NaN. predict() handles NaN gracefully when the quantity is unavailable.

SFR averaging: Time-weighted mean over lookback-time window:

\[\langle\mathrm{SFR}\rangle_T = \frac{\sum_i \mathrm{SFR}_i \Delta t_i}{\sum_i \Delta t_i}\]

where \(i\) ranges over all ages \(\leq T\). Uses symmetric bin widths (jnp.gradient) to avoid trapezoid boundary artifacts.

Mass-weighted age: Computed as

\[t_\mathrm{mw} = \frac{\sum_i w_i t_i}{\sum_i w_i}\]

where \(w_i\) are stellar population weights (age-integrated SFR).

Examples

Single galaxy:

>>> sfh = model.predict_sfh_quantities(params)
>>> sfh.stellar_mass
Array(1.23e10, dtype=float64)

Batch over 10,000 posterior samples:

>>> import jax
>>> sfh_fn = jax.vmap(model.predict_sfh_quantities)
>>> sfh_batch = sfh_fn(params_batch)
>>> sfh_batch.stellar_mass  # shape (10000,)
>>> print(sfh_batch.stellar_mass.mean())

See also

predict

Lazy prediction for single-galaxy exploration (non-JIT).

predict_sfh

SFH on linear-time grid for visualization.

predict_sed_quantities

JIT-compatible SED quantities.

predict_sfh_quantities_via_orchestrator(params)

Drop-in replacement for predict_sfh_quantities().

Routes through the orchestrator and converts the resulting PipelineState to a legacy SFHQuantities NamedTuple via tengri.forward.state_to_sfh_quantities(). Same return shape as the legacy method, computed via the Phase II SEDComponent path.

Returns:

7-field NamedTuple matching the legacy contract.

Return type:

SFHQuantities

predict_spectral_indices(params, index_defs, mode='_traceable')

Predict spectral index values from the model SED.

Generates a rest-frame spectrum covering the index wavelength ranges and measures each index (EW or break ratio).

Parameters:
  • params (dict) – Parameter values (public names).

  • index_defs (tuple of SpectralIndexDef) – Index definitions to measure.

  • mode (str, optional) – Forward-model prediction mode. Default "_traceable".

Returns:

Predicted index values.

Return type:

jnp.ndarray, shape (n_indices,)

Notes

JIT-compatible: depends on mode ("_traceable" by default).

Measures spectral indices (equivalent width or break ratio) from a rest-frame spectrum covering all wavelength ranges in index_defs.

predict_spectrum(params, wave_obs=None, mode='auto', approx=None, wave_chunk_size=None)

Compute observed spectrum at given wavelengths with LSF convolution.

Evaluates the full SED at custom wavelengths in observed frame, applies velocity dispersion broadening (if sigma_v in spec), convolves with instrument line-spread function, and optionally applies multiplicative Chebyshev calibration polynomial.

Parameters:
  • params (dict) – Parameter values using public parameter names.

  • wave_obs (array, optional) –

    Observed-frame wavelength grid [Angstrom]. If None, uses:

    1. Grid from precompute_spectroscopy() if called

    2. Grid from observation.spectroscopy.wave_obs if set

    3. Raises ValueError if neither available

  • mode (str, optional) – Prediction mode (same as predict_photometry()). Default "auto" cascades through available kernels.

  • approx (bool, optional) – Maps True → mode="auto" and False → mode="exact". Prefer passing mode= directly.

  • wave_chunk_size (int, optional) – If specified, split the observed-frame wavelength axis into chunks of this size and evaluate them via jax.lax.map, reducing the per-chunk HLO size that XLA must compile. Default None (no chunking, exact behavior). For R~500 spectroscopy batched over N≥64 galaxies, a typical value is 32–64, which keeps XLA compilation wall-clock time manageable.

Returns:

flux – Observed spectral flux density [erg/s/cm²/Hz] in the AB system at the specified wavelengths.

Return type:

array, shape (n_pix,)

Raises:

ValueError – If wave_obs is None and no precomputed wavelength grid available.

Notes

JIT-compatible: compositional and hybrid modes are JIT’d. Exact mode is not JIT’d.

Velocity dispersion: When sigma_v is in free params, applies line-of-sight broadening via Gaussian convolution at FWHM = 2.355 × sigma_v. Implemented as wavelength-space Gaussian convolution (valid for linear pixels; use apply_lsf() for log-wavelength pixels).
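The velocity broadening step can be sketched as a direct Gaussian sum on a linear wavelength grid, with the kernel width set by σ_λ = λ σ_v / c. The per-output-pixel normalization and the O(N²) sum are choices made for clarity here, not a description of how tengri implements the convolution.

```python
import math

C_KMS = 299792.458  # speed of light [km/s]
FWHM_PER_SIGMA = 2.0 * math.sqrt(2.0 * math.log(2.0))  # the 2.355 factor quoted above


def velocity_broaden(wave_aa, flux, sigma_v_kms):
    """Gaussian line-of-sight broadening on a linear grid, sigma_lambda = lambda * sigma_v / c.

    Each output pixel is a normalized Gaussian-weighted average of the input,
    with the kernel width evaluated at that output wavelength.
    """
    out = []
    for w0 in wave_aa:
        sig = w0 * sigma_v_kms / C_KMS
        kern = [math.exp(-0.5 * ((w - w0) / sig) ** 2) for w in wave_aa]
        out.append(sum(k * f for k, f in zip(kern, flux)) / sum(kern))
    return out


wave = [4900.0 + i for i in range(201)]  # 4900-5100 A, 1 A pixels
flat = velocity_broaden(wave, [1.0] * len(wave), 200.0)
spike = [1.0 if w == 5000.0 else 0.0 for w in wave]
broad = velocity_broaden(wave, spike, 200.0)  # sigma_lambda ~ 3.3 A at 5000 A
```

A flat continuum is unchanged, while a single-pixel line is spread over a few ångströms and its peak drops well below 1.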

Line-spread function: Composition of:

  • Velocity dispersion broadening (σ_v-dependent)

  • Instrument LSF (resolution R-dependent, Gaussian approximation)

  • Chebyshev multiplicative calibration (optional)

The velocity and instrument kernels are applied as convolutions; the Chebyshev polynomial, when enabled, multiplies the broadened spectrum.
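The multiplicative Chebyshev calibration can be sketched with the standard recurrence T₀ = 1, T₁ = x, T_{k+1} = 2x T_k − T_{k−1}. Mapping the observed grid linearly onto [−1, 1] and using c₀ = 1 to mean "no correction" are assumptions; the names of tengri's calibration coefficients are not given in this section.

```python
def chebyshev_calibration(wave_aa, coeffs):
    """Evaluate sum_k c_k T_k(x) on each pixel, with x mapping the grid onto [-1, 1]."""
    lo, hi = wave_aa[0], wave_aa[-1]
    out = []
    for w in wave_aa:
        x = 2.0 * (w - lo) / (hi - lo) - 1.0
        t_prev, t_curr = 1.0, x  # T_0(x), T_1(x)
        total = coeffs[0] * t_prev
        if len(coeffs) > 1:
            total += coeffs[1] * t_curr
        for c in coeffs[2:]:
            t_prev, t_curr = t_curr, 2.0 * x * t_curr - t_prev  # T_{k+1} = 2x T_k - T_{k-1}
            total += c * t_curr
        out.append(total)
    return out


wave = [4000.0, 4500.0, 5000.0]  # maps to x = -1, 0, 1
poly = chebyshev_calibration(wave, [1.0, 0.0, 1.0])  # 1 + T_2(x)
calibrated = [f * p for f, p in zip([2.0, 2.0, 2.0], poly)]
```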

Precomputed wavelength grid: For fixed-redshift models with a fixed wavelength grid, call precompute_spectroscopy(wave_obs) once after initialization to cache the spectroscopy kernels. This enables the hybrid/compositional paths, roughly 10× faster than exact mode.

Wavelength-axis chunking: Set wave_chunk_size to split the observed-frame wavelength axis into ~N/wave_chunk_size chunks that are evaluated independently via lax.map. Each chunk’s HLO is ~1/K of the full HLO (K = N / wave_chunk_size, the number of chunks), which reduces XLA compile time superlinearly. Numerical output is bitwise-identical to the unchunked result. Typical runtime overhead: +5–20% per galaxy from the map.
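The chunking bookkeeping can be sketched without JAX. Here a plain loop stands in for jax.lax.map, and the grid is padded to a whole number of equal-size chunks because lax.map maps over a leading axis of fixed shape. Padding by repeating the last wavelength is an assumption, and the bitwise match in this sketch holds because toy_sed (a hypothetical stand-in model) treats every pixel independently.

```python
import math


def evaluate_chunked(fn, wave, chunk_size):
    """Evaluate fn over wave in equal-size chunks and return the concatenated result."""
    n = len(wave)
    n_chunks = -(-n // chunk_size)  # ceil(n / chunk_size)
    padded = list(wave) + [wave[-1]] * (n_chunks * chunk_size - n)  # equal-shape chunks
    out = []
    for k in range(n_chunks):  # plain loop standing in for jax.lax.map
        out.extend(fn(padded[k * chunk_size:(k + 1) * chunk_size]))
    return out[:n]  # drop the padded tail


def toy_sed(ws):
    """Hypothetical per-pixel model used only to exercise the chunking."""
    return [math.exp(-(((w - 5000.0) / 300.0) ** 2)) for w in ws]


wave = [4000.0 + 1.5 * i for i in range(1000)]
full = toy_sed(wave)
chunked = evaluate_chunked(toy_sed, wave, 64)  # 16 chunks, the last one padded
```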

Examples

>>> import numpy as np
>>> wave_obs = np.linspace(4000, 5500, 1000)  # observed frame [Å]
>>> flux = model.predict_spectrum(params, wave_obs)
>>> import matplotlib.pyplot as plt
>>> plt.plot(wave_obs, flux)
>>> plt.xlabel("Wavelength (Å)")
>>> plt.ylabel("Flux (erg/s/cm²/Hz)")

For large spectroscopy sets with many galaxies, use chunking:

>>> flux = model.predict_spectrum(params, wave_obs, wave_chunk_size=64)

See also

predict_photometry

Filter-integrated flux (simpler, faster).

predict

Lazy access to all SED and SFH quantities.

precompute_spectroscopy

Cache spectroscopy kernels for this grid.

predict_spectrum_batch(params_batch)

Compute spectra for a batch of parameter sets via jax.vmap.

Parameters:

params_batch (dict of arrays) – Each value has leading batch dimension.

Returns:

Spectral flux for each galaxy.

Return type:

array, shape (N, n_pix)

Notes

JIT-compatible: yes — uses jax.vmap() over predict_spectrum().

Examples

>>> import jax.numpy as jnp
>>> # Replicate one parameter dict 1000 times along a new leading batch axis
>>> params_batch = {
...     k: jnp.broadcast_to(jnp.asarray(v), (1000,) + jnp.shape(v))
...     for k, v in params.items()
... }
>>> flux_batch = model.predict_spectrum_batch(params_batch)
>>> flux_batch.shape
(1000, n_pix)

predict_spectrum_via_orchestrator(params, wave_obs=None)

Spectrum through the orchestrator path.

Runs the SEDComponent chain, applies the cosmological redshift + luminosity-distance projection, interpolates onto wave_obs, and (if configured) applies the instrument LSF + velocity-dispersion broadening. Mirrors the contract of the legacy predict_spectrum()’s observed-frame output but goes through the SEDComponent chain rather than the fused kernel.

Parameters:
  • params (Mapping) – Free-parameter dict (same format as for predict_via_orchestrator()).

  • wave_obs (array_like, shape (n_pix,), optional) – Observed-frame wavelength grid [Angstrom]. If None, falls back to the precomputed grid (self._wave_obs or self._precomputed.spectroscopy.wave_obs_pixels).

Returns:

flux – Observed-frame spectral flux density [erg/s/cm²/Hz].

Return type:

ndarray, shape (n_pix,)

Raises:

ValueError – If no wave_obs grid is supplied or precomputed.

Notes

JIT-compatible: yes — run_components(), the rest→obs projection in observe_spectrum_from_rest_sed(), and apply_lsf() are all JIT-friendly. No calibration polynomial is applied; callers that need calibration should compose it on top via the user-likelihood Protocol path.

predict_via_orchestrator(params)

Forward pass via the Phase II SEDComponent orchestrator.

Builds a component chain from this model’s structural settings (self.spec + self.ssp_data + dust / nebular / AGN / radio / X-ray / IGM flags) and threads params through tengri.forward.run_components(). Returns the final tengri.core.PipelineState, not a legacy Prediction — callers wanting the legacy shape should keep using predict_photometry()/predict_spectrum() until the full integration adapter ships.

This is the public bridge between SEDModel’s configuration surface and the orchestrator: it lets users with an existing SEDModel go through run_components without re-typing the chain at every call site.

Parameters:

params (Mapping) – Free parameters keyed by canonical name (sfh_*, met_*, dust_*, agn_*, radio_*, xray_*, igm_*, redshift).

Returns:

Threaded state after the chain runs. sed_intrinsic is the rest-frame total SED in erg/s/Hz; sed_observed is populated when an IGM component is present; derived carries every cross-component publication (L_ir, L_agn_bol, log_mstar, lnu_age, etc.).

Return type:

PipelineState

Notes

JIT-compatible: yes — run_components() and every adapter’s apply are pure JAX.

self.spec.mean_sfh_type is a list (e.g. ["tsnorm"], ["dpl", "field"]); the first entry is the mean SFH model, and "field" anywhere in the list enables the GP-field branch. Anything else (burst, etc.) is currently unmapped and will raise downstream.

predict_xray_quantities_via_orchestrator(params)

X-ray quantities computed through the orchestrator path (Phase II-2.6).

Returns:

Fields l_x_xrb, l_x_agn, and l_x_total.

Return type:

XRayQuantities

prior_predictive(n: int = 500, seed: int = 42) → PriorPredictive

Sample from the prior and evaluate forward model on each draw.

Parameters:
  • n (int) – Number of prior samples. Default 500.

  • seed (int) – Random seed. Default 42.

Returns:

Object containing flux, SFH, and parameter draws with model reference.

Return type:

PriorPredictive

Notes

Useful for prior predictive checks: visualizing what the model predicts under the prior without conditioning on data.

Examples

>>> pp = model.prior_predictive(n=100, seed=42)
>>> # Access photometry, SFH, and parameters from the prior

recommend_method() → str

Return the recommended inference method string for this model.

Returns:

Canonical method name for Fitter.run() or model.fit().

Return type:

str

Notes

Based on model dimensionality, complexity, and available precomputation. Use as input to model.fit(method=model.recommend_method()).

Examples

>>> method = model.recommend_method()
>>> result = model.fit(flux, noise, method=method)

summary() → str

Return a human-readable summary of the model configuration.

Returns:

Formatted summary showing SSP grid, filters, precomputation, fused kernel status, and enabled components.

Return type:

str

Notes

Similar to tree() but focuses on computational configuration and precomputation status rather than physics parameters.

Examples

>>> print(model.summary())

tree() → str

Return a human-readable physics tree showing the model hierarchy.

Shows the active sub-models at each physical layer (SFH, SPS, Dust, Nebular, AGN, Observation), the free parameters at each layer, and the recommended inference method.

Returns:

Multi-line formatted tree string.

Return type:

str

Notes

Useful for inspecting model configuration before fitting or inference.

Examples

>>> print(model.tree())
Model  [D=7, stochastic=False]
...

property wavelengths

Rest-frame wavelength grid (Angstrom).

Returns the SSP grid by default, or the extended panchromatic grid when radio or X-ray emission is enabled.

Returns:

Rest-frame wavelength grid [Angstrom].

Return type:

ndarray, shape (n_wave,)

Notes

This is the grid used by predict_rest_sed() by default when no custom wave= is passed. Updated when radio/X-ray components are added to the model.

Examples

>>> print(model.wavelengths[0], model.wavelengths[-1])
>>> # Default SSP range, e.g. 91.2 to 160000 Å

Parameters

class tengri.Parameters(**kwargs)

Bases: object

Parameter specification defining all model parameters and their priors.

A Parameters object defines the complete parameter set for a SEDModel, including both the mean SFH model(s) and all optional components (dust, nebular emission, AGN, etc.). Parameters can be sampled (for mock data generation) or used as priors for inference.

Parameters are specified as keyword arguments, each of which can be:

  • Scalar (int/float) → Fixed(value) — parameter is constant

  • Tuple (lo, hi) → Uniform(lo, hi) — shorthand for uniform prior

  • Distribution object → Uniform, Gaussian, LogUniform, LogNormal, StudentT, or Fixed
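The shorthand rules can be sketched as a small normalization function. Fixed and Uniform below are hypothetical stand-ins defined locally so the example runs on its own; they are not tengri's classes, and the real constructor may validate inputs differently.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Fixed:
    """Hypothetical stand-in for tengri's Fixed distribution."""
    value: float


@dataclass(frozen=True)
class Uniform:
    """Hypothetical stand-in for tengri's Uniform distribution."""
    lo: float
    hi: float


def normalize_prior(spec):
    """Scalar -> Fixed(value); (lo, hi) tuple -> Uniform(lo, hi); anything else unchanged."""
    if isinstance(spec, (int, float)):
        return Fixed(float(spec))
    if isinstance(spec, tuple) and len(spec) == 2:
        return Uniform(float(spec[0]), float(spec[1]))
    return spec  # already a distribution object


priors = {k: normalize_prior(v) for k, v in {
    "redshift": 0.1,  # scalar shorthand
    "dust_tau_bc": (0.0, 2.0),  # tuple shorthand
    "met_logzsol": Uniform(-2.0, 0.5),  # explicit distribution
}.items()}
```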

A Parameters object also stores model configuration settings (mean_sfh_type, dust_law, nebular mode, etc.) that control which components are enabled. These are not fittable parameters.

Parameters:

**kwargs (keyword arguments) – Model parameters (distribution objects or shorthands) and settings (see “Settings” section below).

mean_sfh_type

Active SFH model(s). Read-only.

Type:

list[str]

n_grid

Grid size for stochastic SFH. Read-only.

Type:

int

stochastic

True if mean_sfh_type includes ‘field’. Read-only.

Type:

bool

all_params

All valid parameter names (free + fixed).

Type:

list[str]

free_params

Non-fixed parameter names (to be inferred).

Type:

list[str]

fixed_params

Fixed parameter names (constants).

Type:

list[str]

n_free

Number of free parameters.

Type:

int

nebular_mode

Nebular emission backend: ‘off’, ‘ssp’, ‘cue’, ‘cloudy’, or ‘cb19’.

Type:

str

dust_model

Dust model: ‘two_component’ or ‘single_component’.

Type:

str

dust_emission

Dust emission template: ‘modified_blackbody’, ‘casey2012’, ‘dale2014’, etc.

Type:

str or None

agn_model

AGN SED model, e.g. ‘kubota_done’, ‘skirtor’, ‘qsogen’.

Type:

str or None

apply_igm

If True, apply IGM absorption (Inoue+2014).

Type:

bool

radio

If True, include radio synchrotron + AGN jet emission.

Type:

bool

xray

If True, include X-ray (XRB + AGN) emission.

Type:

bool

Raises:
  • ValueError – If parameter names are invalid for the selected mean_sfh_type.

  • ValueError – If nebular/dust/AGN settings are mutually incompatible.

Notes

Not JAX-traced: Parameters is the central user-facing object for configuring model parameters and their priors. Parameters objects cannot be created or modified inside JAX-transformed functions (jax.grad, jax.vmap, jax.jit); create all Parameters objects at the Python level before tracing. Once created, a Parameters object is immutable; use the with_params() method to create modified copies.

Parameter auto-detection: If mean_sfh_type is not given explicitly, it is inferred from the parameter name prefixes (e.g., ‘sfh_dpl_alpha’ implies ‘dpl’ is active). The inferred type is normalized to a sorted list.
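The prefix rule can be sketched as below, assuming SFH parameter names follow the sfh_<type>_<name> pattern seen in sfh_dpl_alpha. How tengri names the field (GP) parameters is not stated here, so this covers only prefixed names; infer_mean_sfh_type is a hypothetical helper. Matching the whole <type> token, rather than a string prefix, keeps norm, snorm, and tsnorm distinct.

```python
KNOWN_SFH_TYPES = ("dpl", "tsnorm", "snorm", "norm", "lnorm", "const", "exp", "dexp", "burst", "field")


def infer_mean_sfh_type(param_names):
    """Collect <type> from names of the form 'sfh_<type>_...' and return a sorted list."""
    found = set()
    for name in param_names:
        parts = name.split("_")
        # SFH type names contain no underscores, so the second token is the type
        if len(parts) >= 3 and parts[0] == "sfh" and parts[1] in KNOWN_SFH_TYPES:
            found.add(parts[1])
    return sorted(found)


inferred = infer_mean_sfh_type(["sfh_dpl_alpha", "sfh_dpl_beta", "met_logzsol", "dust_tau_bc"])
```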

Mirror parameters: A parameter can be tied to another by passing the target name as a string instead of a distribution. Example: neb_logZ_gas="met_logzsol" ties gas metallicity to stellar.
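Resolving mirror ties amounts to following string references until a concrete value is reached. The sketch below applies only to parameter values, not to string-valued settings such as dust_law_bc. Following chains (a → b → c) and rejecting cycles are assumptions about sensible behavior, not documented tengri semantics, and resolve_mirrors is hypothetical.

```python
def resolve_mirrors(values):
    """Replace string entries (mirrors) with the value of the parameter they name."""
    resolved = {}
    for name in values:
        seen, target = {name}, values[name]
        while isinstance(target, str):  # keep following the tie
            if target in seen:
                raise ValueError(f"cyclic mirror involving {name!r}")
            seen.add(target)
            target = values[target]
        resolved[name] = target
    return resolved


tied = resolve_mirrors({"met_logzsol": -0.5, "neb_logZ_gas": "met_logzsol"})
```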

Settings (model configuration, not fittable parameters)

mean_sfh_type : str or list[str]

SFH model(s). Composable: ["dpl", "field"]. Options: dpl, tsnorm, snorm, norm, lnorm, const, exp, dexp, burst, field. Default: ["dpl", "field"].

n_grid : int

Grid size for stochastic SFH (latent dimensions). Default: 64.

stochastic : bool

DEPRECATED. Use mean_sfh_type with/without ‘field’ instead.

Dust Attenuation Settings

dust_law_bc : str

Attenuation curve for birth cloud. Default: "power_law". Options: power_law, calzetti, kriek_conroy, smc, cardelli, salim, li08.

dust_law_diff : str

Attenuation curve for diffuse ISM. Default: same as dust_law_bc. Set it to a different law for per-component control.

Dust Emission Settings

dust_emission : str or None

IR emission model. Default: None (disabled). Options: "modified_blackbody", "casey2012", "dale2014", "draine_li2007", "draine_li2014", "dl07_tabulated", "astrodust", "bosa", "themis", "draine2021_pah".

dl07_grid_path : str

Path to DL07 HDF5 template grid (for "dl07_tabulated").

Nebular Emission Settings

nebular_ssp : bool

Use SSP files with pre-included nebular emission (wNE files). No free nebular parameters. Default: False.

nebular : bool

Enable CLOUDY grid nebular emission. Requires cloudy_grid_path. Default: False.

nebular_cue : bool

Enable Cue neural emulator. Default weights are loaded automatically. Default: False.

cloudy_grid_path : str

Path to CLOUDY HDF5 grid. Required when nebular=True.

cue_weights_path : str

Override default Cue weights path.

neb_ionization : str

Ionization source for Cue: "ssp" (default), "agn" (future), "ssp+agn" (future).

AGN Settings

agn_model : str or None

AGN SED model. Default: None (disabled). Options: "simple" (3 params), "standard" (SS73 disc + 2T torus), "kubota_done" (physical disc), "unified_nlr_blr" (NLR/BLR with geometric masking), "qsogen" (empirical quasar, Temple+2021), "skirtor" (clumpy torus RT templates, Stalevski+2016).

Multi-wavelength Settings

radio : bool

Enable radio synchrotron + AGN jet emission. Default: False.

xray : bool

Enable X-ray (XRB + AGN corona) emission. Default: False.

IGM Settings

apply_igm : bool

Apply Inoue+2014 IGM absorption. Default: True.

Metallicity Settings

evolving_metallicity : bool

Replace met_logzsol with met_logzsol_0 (old stars) and met_logzsol_final (young stars) for a linear-in-log Z(t) ramp. Default: False.

met_interp : str

Metallicity interpolation method. Default: "smooth".

  • "smooth": Triweight kernel (same as DSPS, Hearin+2023). 8.5x smoother gradients at <1% speed overhead. Recommended.

  • "linear": 2-point linear in log(Z) (same as FSPS/Prospector).

lgmet_scatter : float

Triweight kernel bandwidth in dex for met_interp="smooth". Default: 0.1 (DSPS default). Physically: intrinsic Z scatter.
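The "smooth" option can be sketched as triweight weights over the SSP metallicity grid. This follows the general DSPS-style approach, but the details are assumptions: the kernel scale is taken as h = lgmet_scatter with support ±3h, bin edges are midpoints between grid values, and the outermost edges extend to ±∞ so the weights always sum to 1. The CDF below is the exact integral of the triweight density 35/96 (1 − y²/9)³ on |y| ≤ 3.

```python
import math


def triweight_cdf(x, mean, h):
    """CDF of a triweight kernel centred on mean with scale h (support mean +/- 3h)."""
    y = (x - mean) / h
    if y <= -3.0:
        return 0.0
    if y >= 3.0:
        return 1.0
    return 0.5 + 35 * y / 96 - 35 * y**3 / 864 + 7 * y**5 / 2592 - 5 * y**7 / 69984


def metallicity_weights(lgmet_grid, lgmet, lgmet_scatter=0.1):
    """Kernel mass falling in each metallicity bin (edges at grid midpoints)."""
    mids = [0.5 * (a + b) for a, b in zip(lgmet_grid[:-1], lgmet_grid[1:])]
    edges = [-math.inf] + mids + [math.inf]
    return [
        triweight_cdf(hi, lgmet, lgmet_scatter) - triweight_cdf(lo, lgmet, lgmet_scatter)
        for lo, hi in zip(edges[:-1], edges[1:])
    ]


grid = [-2.0, -1.5, -1.0, -0.5, 0.0, 0.2]  # illustrative log10(Z/Zsun) grid
w = metallicity_weights(grid, -1.0)
```

Unlike 2-point linear interpolation, the weights vary smoothly as lgmet crosses a bin edge, which is the source of the smoother gradients noted above.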

Fittable Parameters (always available)

Parameter | Default | Description
met_logzsol | Uniform(-2, 0.2) | Stellar metallicity log10(Z/Zsun)
met_alpha_fe | Fixed(0.0) | [alpha/Fe] enhancement (dex)
dust_tau_bc | Uniform(0, 4) | Birth cloud V-band optical depth
dust_tau_diff | Uniform(0, 3) | Diffuse ISM V-band optical depth
dust_slope | Fixed(-0.7) | Power-law index (for power_law curve)
dust_f_obscuration | Fixed(0.0) | Unobscured fraction (Lower+2022)
dust_bump_strength | Fixed(0.0) | UV 2175 Å bump (Kriek & Conroy / Salim)
dust_delta | Fixed(0.0) | Attenuation slope modification
dust_Rv | Fixed(3.1) | R_V (Cardelli curve)
redshift | Fixed(0.1) | Source redshift
noise_frac_cal | Fixed(0.0) | Fractional calibration noise floor
noise_dof | Fixed(0.0) | Student-t degrees of freedom

Conditional Parameters (added when modules enabled)

Nebular (nebular=True):

Parameter | Default | Description
neb_logU | Fixed(-3.0) | Ionization parameter log10(U)
neb_logZ_gas | Fixed(-0.3) | Gas metallicity (None = tie to stellar)
neb_fesc | Fixed(0.0) | Ionizing photon escape fraction
neb_fesc_lya | Fixed(0.0) | Ly-alpha escape fraction

Dust emission (dust_emission != None):

Parameter | Default | Description
dust_T | Fixed(35) | Dust temperature (K) for greybody
dust_beta_ir | Fixed(1.6) | Emissivity index
dust_alpha_mir | Fixed(2.0) | MIR slope (Casey 2012)
dust_alpha_dale | Fixed(2.0) | Dale+2014 alpha
dust_umin | Fixed(1.0) | DL07/DL14 minimum radiation field
dust_gamma_dl | Fixed(0.01) | DL07/DL14 PDR fraction
dust_qpah | Fixed(2.5) | DL07/DL14 PAH mass fraction (%)
dust_alpha_dl14 | Fixed(2.0) | DL14 radiation field slope (1-3)
dust_eta_balance | Fixed(1.0) | Energy balance deviation factor

AGN (agn_model != None):

Parameter | Default | Description
agn_frac | Fixed(0.0) | AGN fraction of stellar L_bol
agn_log_lbol | Fixed(10.0) | AGN log L_bol [erg/s] (parametric)
agn_alpha | Fixed(-1.0) | Disc power-law slope
agn_T_torus | Fixed(1000) | Torus temperature (K)
agn_tau_torus | Fixed(5.0) | Torus optical depth at 9.7 μm
agn_torus_frac | Fixed(0.5) | Torus covering fraction
agn_log_mbh | Fixed(7.0) | Black hole mass log10(M/Msun)
agn_log_ledd | Fixed(-1.0) | Eddington ratio log10(L/L_Edd)
agn_tau_skirtor | Fixed(7.0) | SKIRTOR 9.7 μm optical depth
agn_p_skirtor | Fixed(1.0) | SKIRTOR radial density gradient
agn_q_skirtor | Fixed(1.0) | SKIRTOR polar density gradient
agn_oa_skirtor | Fixed(40) | SKIRTOR opening angle (degrees)
agn_cos_inc | Fixed(0.5) | Cosine of inclination (0 = edge-on)

Radio (radio=True):

Parameter | Default | Description
radio_q_ir | Fixed(2.64) | FIR-radio correlation (Bell 2003)
radio_alpha_sf | Fixed(0.8) | SF synchrotron spectral index
radio_loudness | Fixed(0.0) | AGN radio-loudness log10(L_5GHz/L_B)
radio_alpha_agn | Fixed(0.7) | AGN radio spectral index

X-ray (xray=True):

Parameter | Default | Description
xray_gamma_agn | Fixed(1.8) | AGN X-ray photon index
xray_alpha_ox | Fixed(-1.4) | UV-to-X-ray slope

Evolving metallicity (evolving_metallicity=True):

Parameter | Default | Description
met_logzsol_0 | Uniform(-2, 0.2) | Initial metallicity (oldest stars)
met_logzsol_final | Uniform(-2, 0.2) | Final metallicity (present-day)

Examples

Minimal parametric model:

from tengri import Parameters, Uniform, Fixed

spec = Parameters(
    mean_sfh_type="dpl",
    sfh_dpl_alpha=Uniform(0.5, 3.0),
    sfh_dpl_beta=Uniform(0.5, 3.0),
    sfh_dpl_tau_gyr=Uniform(0.5, 13.0),
    sfh_dpl_log_peak_sfr=Uniform(-1.0, 2.5),
    met_logzsol=Uniform(-2.0, 0.5),
    dust_tau_bc=Uniform(0.0, 2.0),
    dust_tau_diff=Uniform(0.0, 2.0),
    redshift=Fixed(0.1),
)

Full model with all physics:

spec = Parameters(
    mean_sfh_type=["dpl", "field"],
    n_grid=64,
    # Dust attenuation
    dust_law_bc="kriek_conroy",
    dust_f_obscuration=Uniform(0.0, 0.5),
    dust_bump_strength=Uniform(0.0, 5.0),
    # Dust emission (DL07 tabulated templates)
    dust_emission="dl07_tabulated",
    dl07_grid_path="data/dl07_templates.h5",
    dust_umin=Uniform(0.1, 25.0),
    # Nebular (Cue neural emulator)
    nebular_cue=True,
    neb_logU=Uniform(-4.0, -1.0),
    neb_fesc_lya=Uniform(0.0, 1.0),
    # AGN (qsogen empirical quasar)
    agn_model="qsogen",
    agn_log_lbol=Uniform(40.0, 46.0),
    # IGM
    apply_igm=True,
    # Radio + X-ray
    radio=True,
    xray=True,
    # Evolving metallicity
    evolving_metallicity=True,
    met_logzsol_0=Uniform(-2.0, 0.2),
    met_logzsol_final=Uniform(-2.0, 0.2),
    met_alpha_fe=Uniform(-0.2, 0.6),
)
property all_params: list[str]

All parameter names (sorted, excludes settings).

Returns:

Sorted list of all fittable and fixed parameters (free_params + fixed_params).

Return type:

list[str]

property fixed_params: list[str]

Names of fixed parameters.

Returns:

Sorted list of parameter names with constant values (not varied during sampling or inference).

Return type:

list[str]

fixed_value(name: str) float | str | None[source]

Get the fixed value of a parameter.

Parameters:

name (str) – Parameter name (e.g., "redshift").

Returns:

The fixed value. Numeric values are returned as float, string-valued enums as str, and None if the fixed value itself is None.

Return type:

float | str | None

Raises:
  • KeyError – If the parameter name is not in the specification.

  • ValueError – If the parameter is not fixed.

Examples

>>> from tengri import Parameters, Fixed
>>> spec = Parameters(redshift=Fixed(0.1))
>>> spec.fixed_value("redshift")
0.1
property free_params: list[str]

Names of free (non-fixed) parameters.

Returns:

Sorted list of parameter names that are not fixed (vary during sampling and inference).

Return type:

list[str]

get_distribution(name: str) Distribution[source]

Get the prior distribution object for a parameter.

Parameters:

name (str) – Parameter name.

Returns:

The prior distribution object (Uniform, Gaussian, LogUniform, etc.) or Fixed for non-free parameters.

Return type:

Distribution

Raises:

KeyError – If parameter name is not valid for this model configuration.

Notes

For fixed parameters, the returned Distribution has is_fixed=True. For free parameters, the returned Distribution is one of Uniform, Gaussian, LogUniform, LogNormal, StudentT, or other prior types.

Examples

>>> from tengri import Parameters, Uniform
>>> spec = Parameters(
...     dust_tau_bc=Uniform(0, 4),
...     redshift=0.1,
... )
>>> prior = spec.get_distribution("dust_tau_bc")
>>> print(prior)
Uniform(0, 4)
get_fixed_values() dict[str, float][source]

Extract all numeric fixed parameter values as a dict.

Fixed (non-free) parameters are constants that do not vary during inference. This method returns only the numeric ones; categorical Fixed parameters (strings) are excluded.

Parameters:

None

Returns:

Mapping of numeric fixed parameter names to their constant values. String-valued Fixed parameters are not included because they cannot be represented as float.

Return type:

dict[str, float]

Notes

This is useful for freezing parameters before optimization, or for passing to upstream code that requires a flat parameter vector.

Examples

>>> from tengri import Parameters
>>> spec = Parameters(
...     redshift=0.1,
...     dust_tau_bc=(0, 4),
...     eline_broad="broad",  # String-valued
... )
>>> fixed = spec.get_fixed_values()
>>> print(fixed)
{'redshift': 0.1}
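The numeric-only filter can be pictured in plain Python. This is a minimal sketch, assuming the fixed values are available as an ordinary dict; numeric_fixed_values is a hypothetical helper, not part of tengri:

```python
def numeric_fixed_values(fixed: dict) -> dict:
    """Keep numeric fixed values as float; drop strings and None.

    `fixed` is a hypothetical {name: value} dict standing in for the
    spec's internal record of Fixed parameters.
    """
    return {
        name: float(value)
        for name, value in fixed.items()
        if isinstance(value, (int, float))  # str and None are skipped
    }


print(numeric_fixed_values({"redshift": 0.1, "eline_broad": "broad", "agn_T_torus": 1000}))
# {'redshift': 0.1, 'agn_T_torus': 1000.0}
```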
is_fixed(name: str) bool[source]

Check whether a parameter is fixed (non-free).

Parameters:

name (str) – Parameter name (e.g., "redshift", "met_logzsol").

Returns:

True if the parameter is fixed; False if free.

Return type:

bool

Raises:

KeyError – If the parameter name is not in the specification.

Examples

>>> from tengri import Parameters, Uniform, Fixed
>>> spec = Parameters(
...     redshift=Fixed(0.1),
...     dust_tau_bc=Uniform(0, 4),
... )
>>> spec.is_fixed("redshift")
True
>>> spec.is_fixed("dust_tau_bc")
False
property mean_sfh_type: list[str]

SFH model type(s) as a list of strings.

Returns:

Normalized, sorted list of active SFH model names (e.g., ['dpl', 'field']).

Return type:

list[str]

merge_observation_params(**extra_params: Distribution) Parameters[source]

Return a copy augmented with extra observation-level parameters.

Used by inference to inject emission-line amplitude parameters so they flow through bounds, prior penalty loops, and summary output without requiring special-casing in downstream code.

Parameters:

**extra_params (Distribution) – Mapping of parameter name → Distribution to add (e.g., eline_EW_Halpha=Uniform(0, 1000)).

Returns:

New Parameters instance with extra_params included in free_params. The original instance is not modified (immutable pattern).

Return type:

Parameters

Notes

This is distinct from with_params() in that all extra parameters are unconditionally added, whereas with_params() respects user-provided settings.

Examples

>>> from tengri import Parameters, Uniform
>>> spec = Parameters(redshift=0.1)
>>> spec_aug = spec.merge_observation_params(
...     eline_EW_Halpha=Uniform(0, 1000),
...     eline_EW_OIII=Uniform(0, 500),
... )
>>> print(spec_aug.n_free - spec.n_free)
2
property mirrors: dict[str, str]

Parameter mirrors, as {target_name: source_name}.

Returns:

Mapping of tied parameter names to their source parameters. When resolved, the target takes the value of the source.

Return type:

dict[str, str]

property n_free: int

Number of free parameters (excludes sfh_field_xi).

Returns:

Count of all non-fixed parameters available for inference.

Return type:

int

property n_grid: int

GP grid size (only relevant when stochastic=True).

Returns:

Number of latent dimensions for stochastic SFH field.

Return type:

int

resolve_mirrors(params: dict) dict[source]

Copy mirrored parameter values from source to target.

For each mirror (target ← source), copies the sampled source value into the target parameter. Used after sampling so that tied parameters hold identical values. Returns a new dict (immutable pattern).

Parameters:

params (dict[str, ndarray]) – Parameter name → sampled value. Must include all source parameters.

Returns:

New dict with mirrored values filled in. For each target, the sampled value of the source parameter is assigned. Non-mirrored parameters are unchanged.

Return type:

dict[str, ndarray]

Notes

Parameter tying: Mirrors are specified in __init__ by passing a source parameter name as a string instead of a distribution:

Parameters(
    neb_logZ_gas="met_logzsol",  # Gas Z tied to stellar Z
    ...
)

This avoids fixing the target (e.g. Fixed(0)) and copying the source value into it by hand after sampling.

Examples

>>> from tengri import Parameters
>>> spec = Parameters(
...     met_logzsol=(-2, 0.5),
...     neb_logZ_gas="met_logzsol",  # Mirror: neb → met
... )
>>> params = {"met_logzsol": -0.3, "neb_logZ_gas": 0.0}
>>> resolved = spec.resolve_mirrors(params)
>>> print(resolved["neb_logZ_gas"])
-0.3
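Mirror resolution amounts to copy-then-overwrite over the {target: source} map exposed by the mirrors property. A minimal sketch in plain Python; resolve_mirrors_sketch is hypothetical, and it does not handle chained mirrors (a target that is itself another mirror's source), which tengri may or may not support:

```python
def resolve_mirrors_sketch(params: dict, mirrors: dict) -> dict:
    """Return a new dict where each mirror target takes its source's value.

    `mirrors` maps target_name -> source_name. The input dict is never
    modified; a missing source raises KeyError.
    """
    resolved = dict(params)  # shallow copy: leave the caller's dict intact
    for target, source in mirrors.items():
        resolved[target] = params[source]
    return resolved


params = {"met_logzsol": -0.3, "neb_logZ_gas": 0.0}
resolved = resolve_mirrors_sketch(params, {"neb_logZ_gas": "met_logzsol"})
print(resolved["neb_logZ_gas"], params["neb_logZ_gas"])  # -0.3 0.0
```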
sample(key: Array) dict[str, Array][source]

Draw one random sample from all parameter prior distributions.

Samples all free parameters from their priors, returns fixed parameters at their fixed values, and (if stochastic) generates the latent field ξ ~ N(0,I). Mirrors are resolved (target ← source value).

Parameters:

key (jax.Array (PRNGKey)) – Random key for sampling.

Returns:

Parameter name → sampled value. Free parameters are drawn from their prior distributions; fixed parameters take their constant value (float or string). If stochastic, sfh_field_xi is an array of shape (n_grid,). Treat the returned dict as read-only.

Return type:

dict[str, ndarray]

Notes

JIT-compatible: yes — safe to call inside jax.jit() on the key and parameters only (not on branching logic).

Stochastic SFH: When the model includes a GP field, sfh_field_xi is a length-n_grid vector of independent N(0,1) draws. The SED model uses it to generate the stochastic log-SFR perturbations.

Examples

>>> import jax.random
>>> from tengri import Parameters, Uniform
>>> spec = Parameters(
...     sfh_dpl_alpha=Uniform(0.5, 3.0),
...     sfh_dpl_beta=Uniform(0.5, 3.0),
...     redshift=0.1,
... )
>>> key = jax.random.PRNGKey(42)
>>> samples = spec.sample(key)
>>> {'redshift', 'sfh_dpl_alpha', 'sfh_dpl_beta'} <= set(samples)
True
sample_batch(key: Array, n: int) dict[str, Array][source]

Draw n random samples from all parameter prior distributions.

Vectorized sampling via jax.vmap(). Each parameter becomes a batch of n independent samples.

Parameters:
  • key (jax.Array (PRNGKey)) – Random key for sampling.

  • n (int) – Number of independent samples to draw.

Returns:

Parameter name → array of samples. Each entry has shape:

  • (n,) for scalar parameters

  • (n, n_grid) for sfh_field_xi (stochastic SFH only)

Return type:

dict[str, ndarray]

Notes

JIT-compatible: yes. The function is implemented via jax.vmap() applied to the single-sample method.

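Memory: each scalar parameter contributes n values and sfh_field_xi contributes n × n_grid. For n=1000, 20 scalar parameters, and n_grid=64 (stochastic), that is 20,000 + 64,000 = 84,000 values: about 336 KB in float32 or 672 KB in float64.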

Examples

>>> import jax.random
>>> from tengri import Parameters, Uniform
>>> spec = Parameters(
...     sfh_dpl_alpha=Uniform(0.5, 3.0),
...     dust_tau_bc=Uniform(0, 4),
... )
>>> key = jax.random.PRNGKey(0)
>>> batch = spec.sample_batch(key, n=100)
>>> print(batch["sfh_dpl_alpha"].shape)
(100,)
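The size of a batch follows from the shapes listed under Returns. A back-of-the-envelope estimator, assuming every entry is a dense array and ignoring Python and array-object overhead; batch_nbytes is a hypothetical helper:

```python
def batch_nbytes(n: int, n_scalar: int, n_grid: int = 0, itemsize: int = 4) -> int:
    """Approximate bytes held by a sample_batch-style output dict.

    Assumes `n_scalar` entries of shape (n,) plus, when n_grid > 0, one
    latent field of shape (n, n_grid); `itemsize` is 4 for float32 and
    8 for float64. Dict and array-object overhead is ignored.
    """
    n_values = n * n_scalar + n * n_grid
    return n_values * itemsize


print(batch_nbytes(500, 10))               # 20000  (float32, no GP field)
print(batch_nbytes(500, 10, itemsize=8))   # 40000  (float64)
```

With 64-bit mode disabled, JAX defaults to float32, so itemsize=4 is the usual case; pass itemsize=8 when float64 is enabled.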
property stochastic: bool

Whether the model includes a GP field component.

Returns:

True if 'field' is in mean_sfh_type, False otherwise.

Return type:

bool

summary() None[source]

Print a human-readable summary of the model configuration.

Displays SFH type, enabled components (nebular, dust, AGN, etc.), dimensionality, and a table of all parameters grouped by category (free first, then fixed). Useful for printing model status before fitting.

Use summary_str() if you need the underlying string (e.g. for logging) — summary() itself prints and returns None, matching the rest of the discovery API (tengri.summary(), tengri.help(), etc.).

Parameters:

None

Returns:

Output is printed to stdout.

Return type:

None

Notes

Output includes:

  • SFH type and composition

  • Dimensions (n free, latent ξ, mirrored, fixed)

  • Enabled optional modules (nebular, dust_emission, AGN, etc.)

  • Tabular list of parameters with their distributions/values

Examples

>>> from tengri import Parameters, Uniform
>>> spec = Parameters(
...     mean_sfh_type="dpl",
...     sfh_dpl_alpha=Uniform(0.5, 3),
...     dust_tau_bc=Uniform(0, 4),
... )
>>> spec.summary()
Parameters  SFH: dpl
────────────────────────────────────────────────────────────
  Dimensions:  3 free + 6 fixed
  Modules:     none
────────────────────────────────────────────────────────────
Free parameters:
  sfh_dpl_alpha            Uniform(0.5, 3)
  sfh_dpl_beta             Uniform(0.5, 3)
  ...
summary_str() str[source]

Return the summary as a string (e.g. for logging or tests).

property valid_param_names: frozenset

Set of valid parameter names for this model configuration.

Returns:

Immutable set of all parameter names allowed for this configuration, excluding settings and model configuration keys.

Return type:

frozenset

validate(params: dict[str, Array]) None[source]

Check that all parameter values respect their distribution bounds.

Useful before inference or after optimization to ensure no parameter has drifted outside its valid range.

Parameters:

params (dict[str, ndarray or float or str]) – Parameter name → value (sampled or optimized).

Returns:

Returns nothing. Raises an exception if validation fails.

Return type:

None

Raises:

ValueError – If any parameter is outside its bounds. Fixed parameters are always valid.

Notes

Missing parameters (not in dict) are silently ignored — this allows checking partial parameter sets.

Examples

>>> from tengri import Parameters, Uniform
>>> spec = Parameters(dust_tau_bc=Uniform(0, 4))
>>> params_valid = {"dust_tau_bc": 2.0}
>>> spec.validate(params_valid)  # OK
>>> params_bad = {"dust_tau_bc": 5.0}
>>> spec.validate(params_bad)  # Raises ValueError
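validate() can be pictured as an interval check. A minimal sketch, assuming each free parameter's prior support reduces to a closed interval (low, high); bounds and validate_sketch are hypothetical names, and priors with unbounded support would use ±inf. As the Notes describe, names missing from the input are simply not checked:

```python
def validate_sketch(params: dict, bounds: dict) -> None:
    """Raise ValueError if a provided value lies outside its closed interval.

    `bounds` is a hypothetical {name: (low, high)} map of prior supports.
    """
    for name, value in params.items():
        if name not in bounds:
            # No bounds recorded (e.g. a fixed parameter): always valid
            continue
        low, high = bounds[name]
        # `not (low <= value <= high)` also rejects NaN
        if not (low <= value <= high):
            raise ValueError(f"{name}={value} outside [{low}, {high}]")


bounds = {"dust_tau_bc": (0.0, 4.0)}  # hypothetical support of Uniform(0, 4)
validate_sketch({"dust_tau_bc": 2.0}, bounds)  # passes silently
try:
    validate_sketch({"dust_tau_bc": 5.0}, bounds)
except ValueError as err:
    print(err)  # dust_tau_bc=5.0 outside [0.0, 4.0]
```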
with_params(**kwargs) Parameters[source]

Return a new Parameters with additional parameters merged in.

Creates an independent copy of this Parameters with extra parameters added (usually observation-level parameters like calibration or noise). User-defined parameters take precedence — if a name already exists in this spec, the new value is silently skipped (user intent is preserved).

Typically used internally by SEDModel to auto-inject noise and calibration parameters into the specification.

Parameters:

**kwargs – Parameter name → Distribution (or scalar/tuple shorthand). Only params not already explicitly provided by the user are added.

Returns:

New instance with merged parameters. The original is not modified (immutable pattern).

Return type:

Parameters

Notes

Immutability: The original Parameters object is never modified. A new Parameters instance is created via copy.copy(), with internal mutable structures (dicts) replaced with copies.

Parameter priority: User-provided parameters (via __init__) take absolute precedence. Auto-merged parameters are added only if their name is not in the user-provided set.

Examples

>>> from tengri import Parameters, Uniform
>>> spec = Parameters(redshift=0.1)
>>> # Merge in observation-level calibration parameters
>>> spec_aug = spec.with_params(
...     cal_offset_aper=Uniform(-0.1, 0.1),
...     noise_frac=Uniform(0, 0.05),
... )
>>> print(set(spec_aug.all_params) - set(spec.all_params))
{'cal_offset_aper', 'noise_frac'}
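The precedence rule behaves like dict.setdefault applied to a copy. A minimal sketch in plain Python; with_params_sketch is hypothetical, and strings stand in for Distribution objects:

```python
def with_params_sketch(existing: dict, **extra) -> dict:
    """Return a new dict; names already in `existing` keep their value."""
    merged = dict(existing)  # copy, so the original is never modified
    for name, value in extra.items():
        merged.setdefault(name, value)  # silently skip names already present
    return merged


user = {"redshift": 0.1, "noise_frac": "Uniform(0, 0.1)"}
merged = with_params_sketch(
    user,
    noise_frac="Uniform(0, 0.05)",
    cal_offset_aper="Uniform(-0.1, 0.1)",
)
print(merged["noise_frac"])  # Uniform(0, 0.1)
print(sorted(merged))        # ['cal_offset_aper', 'noise_frac', 'redshift']
```

merge_observation_params() differs in that every extra name is added unconditionally; how it resolves a name clash is not specified here, so the sketch does not model it.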

Prediction

class tengri.Prediction(model, params)[source]

Bases: object

Lazy prediction object with on-demand computation of derived quantities.

Created via model.predict(params). Properties are computed on first access and cached. The cache is shared across all property groups (sfh, sed, lines, radio, xray, ionizing), so related quantities share the expensive intermediates.

Parameters:
  • model (SEDModel) – The tengri SEDModel instance.

  • params (dict) – Parameter values (public names).

sfh

Star formation history derived quantities. Lazy accessor.

Type:

SFHProperties

sed

Spectral energy distribution derived quantities. Lazy accessor.

Type:

SEDProperties

lines

Emission line luminosities and diagnostic ratios. Lazy accessor.

Type:

LineProperties

radio

Radio-derived quantities from empirical relations. Lazy accessor.

Type:

RadioProperties

xray

X-ray derived quantities from empirical relations. Lazy accessor.

Type:

XRayProperties

ionizing

Ionizing photon budget quantities. Lazy accessor.

Type:

IonizingProperties

Returns:

Lazy prediction object with cached computed quantities.

Return type:

Prediction

Notes

This class is NOT JIT-compatible due to Python-level caching. For batch computations over many parameter sets (MCMC chains, mock catalogs), use the JIT-compatible methods SEDModel.predict_sfh_quantities(), SEDModel.predict_sed_quantities(), etc. instead. Those return JAX pytrees (SFHQuantities, SEDQuantities, DerivedQuantities) suitable for jax.vmap(), jax.jit(), and jax.grad().

Examples

Two equivalent ways to access derived quantities:

>>> pred = model.predict(params)
>>> pred.stellar_mass  # flat shortcut
>>> pred.sfh.stellar_mass  # grouped form (same value)
>>> pred.dn4000  # flat
>>> pred.sed.dn4000  # grouped
>>> pred.halpha  # flat
>>> pred.lines.halpha  # grouped

The grouped form (pred.sfh, pred.sed, pred.lines, pred.radio, pred.xray, pred.ionizing) exposes every derived quantity. The top-level shortcuts cover the most-used quantities for quick access; for less-common ones use the grouped form. Both share the same lazy cache, so accessing a quantity by either route triggers computation only once.

Accessing the full SED or photometry:

>>> pred.sed_array  # shape (n_wave,)
>>> pred.photometry  # shape (n_filters,)

For batch computation, use JIT-compatible methods instead:

>>> sfh_batch = jax.vmap(model.predict_sfh_quantities)(params_batch)
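The shared lazy cache can be illustrated with a small plain-Python sketch. The class names and the compute_mass callable are hypothetical; this is not tengri's implementation, only a demonstration of why the flat shortcut and the grouped form trigger a single computation:

```python
from functools import cached_property


class _SFHGroup:
    """Hypothetical grouped accessor reading from a cache it shares."""

    def __init__(self, cache, compute_mass):
        self._cache = cache
        self._compute_mass = compute_mass

    @property
    def stellar_mass(self):
        # Compute on first access from either route, then reuse
        if "stellar_mass" not in self._cache:
            self._cache["stellar_mass"] = self._compute_mass()
        return self._cache["stellar_mass"]


class LazyPredictionSketch:
    def __init__(self, compute_mass):
        self._cache = {}  # one cache for every property group
        self._compute_mass = compute_mass

    @cached_property
    def sfh(self):
        return _SFHGroup(self._cache, self._compute_mass)

    @property
    def stellar_mass(self):
        # Flat shortcut simply forwards to the grouped form
        return self.sfh.stellar_mass


calls = []


def expensive_mass():
    calls.append(1)  # count how often the "expensive" step runs
    return 1e10


pred = LazyPredictionSketch(expensive_mass)
assert pred.stellar_mass == pred.sfh.stellar_mass
print(len(calls))  # 1
```

Because the cache is mutable Python state, an object built this way cannot be traced by jax.jit, which is the same constraint the Notes describe for Prediction.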
property balmer_break

Balmer break flux ratio. Same as pred.sed.balmer_break.

property balmer_decrement

Hα/Hβ flux ratio. Same as pred.lines.balmer_decrement.

property dn4000

D_n(4000) break ratio. Same as pred.sed.dn4000.

property fuv_flux

Mean FUV (1000–1700 Å) flux density [erg/s/Hz]. Same as pred.sed.fuv_flux.

property fuv_flux_intrinsic

Dust-free FUV flux. Same as pred.sed.fuv_flux_intrinsic.

property halpha

Hα 6564 Å luminosity [erg/s]. Same as pred.lines.halpha.

property hbeta

Hβ 4862 Å luminosity [erg/s]. Same as pred.lines.hbeta.

property irx

Infrared excess L_TIR / L_UV(1600 Å). Same as pred.sed.irx.

property l_bol

Bolometric luminosity [L☉]. Same as pred.sed.l_bol.

property l_dust_absorbed

Dust-absorbed luminosity [L☉]. Same as pred.sed.l_dust_absorbed.

property l_tir

Total infrared (8–1000 μm) luminosity [L☉]. Same as pred.sed.l_tir.

property luminosity_weighted_age_gyr

Luminosity-weighted age [Gyr]. Same as pred.sed.luminosity_weighted_age_gyr.

Both pred.sfh and pred.sed define this; the top-level shortcut forwards to the SED version (canonical “luminosity-weighted” using the attenuated stellar SED).

property luminosity_weighted_metallicity

Luminosity-weighted log₁₀(Z/Z☉).

Same as pred.sed.luminosity_weighted_metallicity.

property m_uv

Absolute magnitude at 1500 Å. Same as pred.sed.m_uv.

property mass_weighted_age_gyr

Mass-weighted stellar age [Gyr]. Same as pred.sfh.mass_weighted_age_gyr.

property mass_weighted_metallicity

Mass-weighted log₁₀(Z/Z☉). Same as pred.sfh.mass_weighted_metallicity.

property nuv_flux

Mean NUV (1700–3200 Å) flux density [erg/s/Hz]. Same as pred.sed.nuv_flux.

property nuv_flux_intrinsic

Dust-free NUV flux. Same as pred.sed.nuv_flux_intrinsic.

property oiii_5007

[O III] 5007 Å luminosity [erg/s]. Same as pred.lines.oiii_5007.

property photometry

Observed photometric flux densities.

Returns:

Photometry at the filters defined in the SEDModel [erg/s/cm²/Hz].

Return type:

ndarray, shape (n_filters,)

Notes

JIT-compatible: no — Python property accessor. Use in postprocessing, not inside jax.jit().

Examples

pred = model.predict(params)
phot = pred.photometry  # ndarray, shape (n_filters,)
print(phot.shape)  # e.g. (8,) for 8 photometric bands
property q_h

Total ionizing photon production rate [s⁻¹]. Same as pred.ionizing.q_h.

property rest_uv_color

Rest-frame UV color (f_1500 − f_2300). Same as pred.sed.rest_uv_color.

property sed_array

Full rest-frame SED array.

Returns:

Total spectral energy distribution [erg/s/Hz].

Return type:

ndarray, shape (n_wave,)

Notes

JIT-compatible: no — Python property accessor. Use in postprocessing, not inside jax.jit().

property sfr_100myr

SFR averaged over last 100 Myr [M☉/yr]. Same as pred.sfh.sfr_100myr.

property sfr_10myr

SFR averaged over last 10 Myr [M☉/yr]. Same as pred.sfh.sfr_10myr.

property ssfr

Specific SFR [yr⁻¹]. Same as pred.sfh.ssfr.

property stellar_mass

Total stellar mass formed [M☉]. Same as pred.sfh.stellar_mass.

property stellar_mass_surviving

Surviving stellar + remnant mass [M☉]. Same as pred.sfh.stellar_mass_surviving.

property uv_slope_beta

UV slope β in f_λ ∝ λ^β. Same as pred.sed.uv_slope_beta.

property xi_ion

Ionizing photon production efficiency [Hz·erg⁻¹]. Same as pred.ionizing.xi_ion.


SFHQuantities

class tengri.SFHQuantities(stellar_mass: Array, stellar_mass_surviving: Array, sfr_100myr: Array, sfr_10myr: Array, ssfr: Array, mass_weighted_age_gyr: Array, mass_weighted_metallicity: Array)[source]

Bases: NamedTuple

Derived quantities from the star formation history.

All fields are JAX arrays (scalars). This is a proper JAX pytree, so it works with jax.jit, jax.vmap, and jax.grad.

stellar_mass

Total formed stellar mass [Msun].

Type:

jnp.ndarray

stellar_mass_surviving

Surviving mass in living stars + remnants [Msun]. Returns NaN if the mass-remaining table was not loaded.

Type:

jnp.ndarray

sfr_100myr

Star formation rate averaged over the last 100 Myr [Msun/yr].

Type:

jnp.ndarray

sfr_10myr

Star formation rate averaged over the last 10 Myr [Msun/yr].

Type:

jnp.ndarray

ssfr

Specific star formation rate SFR/M* [1/yr].

Type:

jnp.ndarray

mass_weighted_age_gyr

Mass-weighted stellar age [Gyr].

Type:

jnp.ndarray

mass_weighted_metallicity

Mass-weighted metallicity log10(Z), evolving-Z aware.

Type:

jnp.ndarray

Notes

JAX-compatible array container. All fields are JAX arrays compatible with jax.jit and jax.vmap. Returned by SEDModel.predict_sfh_quantities() and Prediction.sfh when accessed.

Examples

>>> import jax.numpy as jnp
>>> from tengri import SFHQuantities
>>> q = SFHQuantities(
...     stellar_mass=jnp.array(1e10),
...     stellar_mass_surviving=jnp.array(6e9),
...     sfr_100myr=jnp.array(5.0),
...     sfr_10myr=jnp.array(8.0),
...     ssfr=jnp.array(5e-10),
...     mass_weighted_age_gyr=jnp.array(3.5),
...     mass_weighted_metallicity=jnp.array(-0.5),
... )
>>> float(q.stellar_mass)
10000000000.0
>>> "stellar_mass" in q._fields and "sfr_100myr" in q._fields
True
mass_weighted_age_gyr: Array

Alias for field number 5

mass_weighted_metallicity: Array

Alias for field number 6

sfr_100myr: Array

Alias for field number 2

sfr_10myr: Array

Alias for field number 3

ssfr: Array

Alias for field number 4

stellar_mass: Array

Alias for field number 0

stellar_mass_surviving: Array

Alias for field number 1
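The SFH fields are standard sums over the star formation history. A toy sketch, assuming a piecewise-constant SFH on lookback-time bins; sfh_summary is hypothetical, tengri evaluates these quantities on its own time grid, and stellar_mass_surviving (which needs the SSP mass-remaining tables) is omitted. Defining sSFR as the 100 Myr SFR over formed mass is an assumption, chosen because it matches the numbers in the example above:

```python
def sfh_summary(edges_yr, sfr_msun_yr):
    """Toy SFHQuantities-style numbers from a piecewise-constant SFH.

    edges_yr: lookback-time bin edges in yr, increasing from 0 (today).
    sfr_msun_yr: constant SFR [Msun/yr] in each bin (len(edges_yr) - 1 values).
    """
    bins = list(zip(edges_yr[:-1], edges_yr[1:], sfr_msun_yr))

    def mass_formed(t0, t1):
        # Mass formed in the lookback window [t0, t1], summing bin overlaps
        return sum(s * max(0.0, min(hi, t1) - max(lo, t0)) for lo, hi, s in bins)

    mass = mass_formed(0.0, edges_yr[-1])
    sfr_100 = mass_formed(0.0, 1e8) / 1e8
    sfr_10 = mass_formed(0.0, 1e7) / 1e7
    # Mass-weighted age: each bin's formed mass times its midpoint lookback time
    age_yr = sum(s * (hi - lo) * 0.5 * (lo + hi) for lo, hi, s in bins) / mass
    return {
        "stellar_mass": mass,
        "sfr_100myr": sfr_100,
        "sfr_10myr": sfr_10,
        "ssfr": sfr_100 / mass,  # assumption: 100 Myr SFR over formed mass
        "mass_weighted_age_gyr": age_yr / 1e9,
    }


q = sfh_summary([0.0, 1e8, 1e9], [10.0, 1.0])
print(q["stellar_mass"], q["sfr_100myr"])  # 1900000000.0 10.0
print(round(q["mass_weighted_age_gyr"], 3))  # 0.287
```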

SEDQuantities

class tengri.SEDQuantities(l_bol: Array, l_tir: Array, l_dust_absorbed: Array, irx: Array, uv_slope_beta: Array, dn4000: Array, balmer_break: Array, m_uv: Array, fuv_flux: Array, nuv_flux: Array, fuv_flux_intrinsic: Array, nuv_flux_intrinsic: Array, rest_uv_color: Array, luminosity_weighted_age_gyr: Array, luminosity_weighted_metallicity: Array)[source]

Bases: NamedTuple

Derived quantities from the spectral energy distribution.

All fields are JAX arrays. Proper JAX pytree for jit/vmap.

l_bol

Bolometric luminosity [Lsun].

Type:

jnp.ndarray

l_tir

Total infrared luminosity 8–1000 μm [Lsun].

Type:

jnp.ndarray

l_dust_absorbed

Dust-absorbed luminosity [Lsun]. Returns NaN if no intrinsic SED.

Type:

jnp.ndarray

irx

Infrared excess log10(L_TIR / νLν_1600) [dimensionless].

Type:

jnp.ndarray

uv_slope_beta

UV spectral slope β, measured over 1250–2600 Å [dimensionless].

Type:

jnp.ndarray

dn4000

Narrow 4000 Å break, Balogh et al. 1999 [dimensionless].

Type:

jnp.ndarray

balmer_break

Modified Balmer break, Wang et al. 2024 [dimensionless].

Type:

jnp.ndarray

m_uv

Absolute UV magnitude at rest-frame 1500 Å [AB].

Type:

jnp.ndarray

fuv_flux

Mean flux density in FUV 1000–1700 Å [erg/s/Hz].

Type:

jnp.ndarray

nuv_flux

Mean flux density in NUV 1700–3200 Å [erg/s/Hz].

Type:

jnp.ndarray

fuv_flux_intrinsic

Dust-free FUV flux [erg/s/Hz]. Returns NaN if no intrinsic SED.

Type:

jnp.ndarray

nuv_flux_intrinsic

Dust-free NUV flux [erg/s/Hz]. Returns NaN if no intrinsic SED.

Type:

jnp.ndarray

rest_uv_color

Rest-frame U-V color [AB magnitudes].

Type:

jnp.ndarray

luminosity_weighted_age_gyr

Luminosity-weighted age [Gyr].

Type:

jnp.ndarray

luminosity_weighted_metallicity

Luminosity-weighted metallicity log10(Z).

Type:

jnp.ndarray

Notes

JAX-compatible array container. All fields are JAX arrays compatible with jax.jit and jax.vmap. Returned by SEDModel.predict_sed_quantities() and Prediction.sed when accessed.

Examples

Access via Prediction.sed after calling SEDModel.predict():

pred = model.predict(params)
sed = pred.sed  # SEDQuantities NamedTuple
print(float(sed.l_bol))  # bolometric luminosity [Lsun]
print(float(sed.dn4000))  # 4000 Å break strength
print(float(sed.uv_slope_beta))  # UV slope beta
balmer_break: Array

Alias for field number 6

dn4000: Array

Alias for field number 5

fuv_flux: Array

Alias for field number 8

fuv_flux_intrinsic: Array

Alias for field number 10

irx: Array

Alias for field number 3

l_bol: Array

Alias for field number 0

l_dust_absorbed: Array

Alias for field number 2

l_tir: Array

Alias for field number 1

luminosity_weighted_age_gyr: Array

Alias for field number 13

luminosity_weighted_metallicity: Array

Alias for field number 14

m_uv: Array

Alias for field number 7

nuv_flux: Array

Alias for field number 9

nuv_flux_intrinsic: Array

Alias for field number 11

rest_uv_color: Array

Alias for field number 12

uv_slope_beta: Array

Alias for field number 4
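Two of these fields have compact textbook definitions. A minimal sketch in plain Python, assuming rest-frame spectra sampled on a wavelength grid in Å: β as the least-squares slope of log f_λ against log λ over 1250–2600 Å, and D_n(4000) as the ratio of mean f_ν in 4000–4100 Å to 3850–3950 Å (Balogh et al. 1999). Both functions are hypothetical; tengri's exact windows, weighting, and interpolation may differ:

```python
import math


def uv_slope_beta(wave_aa, f_lambda, lo=1250.0, hi=2600.0):
    """Least-squares slope of log10 f_lambda vs log10 lambda inside [lo, hi] A."""
    pts = [
        (math.log10(w), math.log10(f))
        for w, f in zip(wave_aa, f_lambda)
        if lo <= w <= hi and f > 0  # log needs positive flux
    ]
    n = len(pts)
    mean_x = sum(x for x, _ in pts) / n
    mean_y = sum(y for _, y in pts) / n
    s_xy = sum((x - mean_x) * (y - mean_y) for x, y in pts)
    s_xx = sum((x - mean_x) ** 2 for x, _ in pts)
    return s_xy / s_xx


def dn4000(wave_aa, f_nu):
    """Mean f_nu in 4000-4100 A divided by mean f_nu in 3850-3950 A.

    A plain average of samples, which assumes a uniform wavelength grid.
    """
    red = [f for w, f in zip(wave_aa, f_nu) if 4000.0 <= w <= 4100.0]
    blue = [f for w, f in zip(wave_aa, f_nu) if 3850.0 <= w <= 3950.0]
    return (sum(red) / len(red)) / (sum(blue) / len(blue))


wave = [1200.0 + 10.0 * i for i in range(200)]
print(round(uv_slope_beta(wave, [w ** -2 for w in wave]), 6))  # -2.0
```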

DerivedQuantities

class tengri.DerivedQuantities(sfh: SFHQuantities, sed: SEDQuantities)[source]

Bases: NamedTuple

All derived physical quantities (convenience container).

Returned by model.predict_derived().

sfh

Star formation history derived quantities.

Type:

SFHQuantities

sed

Spectral energy distribution derived quantities.

Type:

SEDQuantities

Notes

JAX-compatible array container combining SFHQuantities and SEDQuantities. Compatible with jax.jit and jax.vmap. Returned by SEDModel.predict_derived().

Examples

from tengri import DerivedQuantities

derived = model.predict_derived(params)
print(float(derived.sfh.stellar_mass))  # [Msun]
print(float(derived.sed.dn4000))  # 4000 Å break
print(float(derived.sed.uv_slope_beta))  # UV slope β
sed: SEDQuantities

Alias for field number 1

sfh: SFHQuantities

Alias for field number 0

EmissionLines

class tengri.EmissionLines(lya: Array, civ_1549: Array, oii: Array, hbeta: Array, oiii_4959: Array, oiii_5007: Array, nii_6548: Array, halpha: Array, nii_6584: Array, sii_6717: Array, sii_6731: Array)[source]

Bases: NamedTuple

Key emission line luminosities.

NaN for all fields when no nebular model is active. For doublets ([OII], C IV), the luminosities of both components are summed.

lya

Lyman-alpha at 1216 Å [Lsun].

Type:

jnp.ndarray

civ_1549

C IV doublet 1548+1551 Å, summed [Lsun].

Type:

jnp.ndarray

oii

[OII] doublet 3726+3729 Å, summed [Lsun].

Type:

jnp.ndarray

hbeta

H-beta at 4861 Å [Lsun].

Type:

jnp.ndarray

oiii_4959

[OIII] at 4959 Å [Lsun].

Type:

jnp.ndarray

oiii_5007

[OIII] at 5007 Å [Lsun].

Type:

jnp.ndarray

nii_6548

[NII] at 6548 Å [Lsun].

Type:

jnp.ndarray

halpha

H-alpha at 6563 Å [Lsun].

Type:

jnp.ndarray

nii_6584

[NII] at 6584 Å [Lsun].

Type:

jnp.ndarray

sii_6717

[SII] at 6717 Å [Lsun].

Type:

jnp.ndarray

sii_6731

[SII] at 6731 Å [Lsun].

Type:

jnp.ndarray

Returns:

This is a NamedTuple (JAX pytree) returned by SEDModel.predict_emission_lines().

Notes

JAX-compatible array container. All fields are JAX arrays compatible with jax.jit and jax.vmap. Returned by Prediction.lines when accessed. All fields return NaN if no nebular model is active in the SEDModel.

Examples

Access via Prediction.lines after calling SEDModel.predict():

import jax.numpy as jnp

pred = model.predict(params)
lines = pred.lines  # EmissionLines NamedTuple
print(float(lines.halpha))  # H-alpha luminosity [Lsun]
print(float(lines.oiii_5007))  # [OIII] 5007 Å luminosity [Lsun]
# BPT diagram coordinates (log10 line ratios)
bpt_x = float(jnp.log10(lines.nii_6584 / lines.halpha))
bpt_y = float(jnp.log10(lines.oiii_5007 / lines.hbeta))
civ_1549: Array

Alias for field number 1

halpha: Array

Alias for field number 7

hbeta: Array

Alias for field number 3

lya: Array

Alias for field number 0

nii_6548: Array

Alias for field number 6

nii_6584: Array

Alias for field number 8

oii: Array

Alias for field number 2

oiii_4959: Array

Alias for field number 4

oiii_5007: Array

Alias for field number 5

sii_6717: Array

Alias for field number 9

sii_6731: Array

Alias for field number 10
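BPT diagrams plot log line ratios, so these fields are usually combined as log10([NII]/Hα) and log10([OIII]/Hβ). A minimal sketch for plain Python floats (convert JAX scalars with float() first); bpt_coordinates is a hypothetical helper that passes NaN through for models without a nebular component. Zero or negative inputs would make math.log10 raise ValueError:

```python
import math


def bpt_coordinates(nii_6584, halpha, oiii_5007, hbeta):
    """Return (log10([NII]6584/Halpha), log10([OIII]5007/Hbeta)).

    Ratios are dimensionless, so Lsun or erg/s inputs give the same result.
    NaN inputs (no nebular model) yield (nan, nan) instead of raising.
    """
    values = (nii_6584, halpha, oiii_5007, hbeta)
    if any(math.isnan(v) for v in values):
        return (math.nan, math.nan)
    return (math.log10(nii_6584 / halpha), math.log10(oiii_5007 / hbeta))


x, y = bpt_coordinates(0.5, 5.0, 10.0, 1.0)
print(round(x, 3), round(y, 3))
```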

MockData

class tengri.MockData(flux_true: jnp.ndarray, flux_obs: jnp.ndarray, noise: jnp.ndarray, params: dict)[source]

Bases: NamedTuple

Container for mock galaxy observations.

Parameters:
  • flux_true (ndarray, shape (n_filters,)) – Noiseless model photometry. [erg/s/cm²/Hz]

  • flux_obs (ndarray, shape (n_filters,)) – Noisy photometry with Gaussian scatter added. [erg/s/cm²/Hz]

  • noise (ndarray, shape (n_filters,)) – 1-sigma photometric uncertainties used to draw the noise. [erg/s/cm²/Hz]

  • params (dict) – Input physical parameters used to generate the mock.

Returns:

Named tuple containing noiseless and noisy photometry.

Return type:

MockData

flux_true

Noiseless model photometry. [erg/s/cm²/Hz]

Type:

ndarray, shape (n_filters,)

flux_obs

Noisy photometry with Gaussian scatter added. [erg/s/cm²/Hz]

Type:

ndarray, shape (n_filters,)

noise

1-sigma photometric uncertainties used to draw the noise. [erg/s/cm²/Hz]

Type:

ndarray, shape (n_filters,)

params

Input physical parameters used to generate the mock.

Type:

dict

Notes

JIT-compatible: yes — NamedTuple is a JAX pytree.

Immutable: All fields are read-only by design. To create a modified version, use the _replace() method inherited from NamedTuple.

Examples

from tengri import SEDModel, Observation, Photometry  # Photometry import path assumed
obs = Observation(photometry=Photometry.from_names(["hst_acs_f606w", "hst_acs_f814w"]))
model = SEDModel(spec, ssp_data, observation=obs)
params = {"sfh_dpl_alpha": 2.0, "sfh_dpl_beta": 1.5, ...}
mock = model.make_mock(params, snr=20.0)
mock.flux_obs.shape    # (n_filters,)
mock.plot()            # matplotlib Figure
flux_obs: Array

Alias for field number 1

flux_true: Array

Alias for field number 0

noise: Array

Alias for field number 2

params: dict

Alias for field number 3

plot(filter_names=None, ax=None)[source]

Plot mock photometry with errorbars.

Parameters:
  • filter_names (list of str, optional) – Filter labels for the x-axis. Falls back to integer indices if None.

  • ax (matplotlib Axes, optional) – Axes to plot on. Creates new figure if None.

Returns:

fig – Matplotlib figure with photometry plotted as error bars (observed with noise) and markers (true noiseless).

Return type:

matplotlib Figure

Notes

JIT-compatible: no — uses matplotlib for visualization.

generate_mock

tengri.generate_mock(model, params, key=None, snr=20.0)[source]

Generate mock galaxy photometry with optional Gaussian noise.

Computes noiseless predicted photometry, then optionally realizes noise at a specified signal-to-noise ratio. Useful for testing data pipelines, validating inference, and parameter recovery studies.

Parameters:
  • model (object) – Any object with a predict_photometry(params) method that returns an array of flux densities.

  • params (dict[str, ndarray]) – Model parameter values (typically sampled or optimized).

  • key (jax.Array (PRNGKey), optional) – Random key for noise realization. If None, only noiseless photometry is returned (no flux_obs key in output).

  • snr (float, optional) – Signal-to-noise ratio (flux_true / noise_std). Default: 20.0.

Returns:

Mock observation data with keys:

  • flux_true : noiseless predicted photometry [erg/s/cm²/Hz]

  • noise : noise standard deviation per band [erg/s/cm²/Hz]

  • params : the input parameter values

  • flux_obs : observed (noisy) photometry (only if key is not None)

Return type:

dict

Notes

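Noise model: Assumes Gaussian noise with σ = flux_true / SNR, i.e. the same fractional uncertainty in every band. This is a simple idealization rather than a photon-limited model: Poisson noise scales as √flux, so its S/N would vary from band to band.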

Examples

>>> import jax.random
>>> from tengri import SEDModel, generate_mock
>>> model = SEDModel(...)
>>> params = {'redshift': 0.1, ...}
>>> key = jax.random.PRNGKey(42)
>>> mock = generate_mock(model, params, key=key, snr=10.0)
>>> print(f"True flux shape: {mock['flux_true'].shape}")
>>> print(f"Obs. flux shape: {mock['flux_obs'].shape}")
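The documented noise model reduces to one line per band. A minimal sketch in plain Python, assuming list inputs; generate_mock_sketch is hypothetical and uses Python's random module with an integer seed in place of a JAX PRNG key:

```python
import random


def generate_mock_sketch(flux_true, params=None, seed=None, snr=20.0):
    """Constant-S/N Gaussian mock following the documented noise model.

    noise_i = flux_true_i / snr in every band. flux_obs is drawn only when
    a seed is given; the seed stands in for generate_mock's PRNG key.
    """
    noise = [f / snr for f in flux_true]
    mock = {"flux_true": list(flux_true), "noise": noise, "params": params}
    if seed is not None:
        rng = random.Random(seed)  # reproducible draws for a given seed
        mock["flux_obs"] = [f + s * rng.gauss(0.0, 1.0) for f, s in zip(flux_true, noise)]
    return mock


mock = generate_mock_sketch([1.0, 2.0], params={"redshift": 0.1}, seed=42, snr=20.0)
print(mock["noise"])  # [0.05, 0.1]
```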

Observation

class tengri.Observation(photometry: Photometry | None = None, spectroscopy: Spectroscopy | None = None, noise: NoiseModel | None = None, line_fluxes: LineFluxData | None = None, spectral_indices: SpectralIndexData | None = None)[source]

Bases: object

Unified observation configuration.

Composes optional photometric, spectroscopic, and noise model configurations. At least one of photometry or spectroscopy must be provided.

Parameters:
  • photometry (Photometry or None) – Photometric filter configuration.

  • spectroscopy (Spectroscopy or None) – Spectroscopic instrument configuration.

  • noise (NoiseModel or None) – Noise model configuration (calibration floor, Student-t dof).

  • line_fluxes (LineFluxData or None) – Observed emission line fluxes for direct fitting. When provided, the likelihood includes an additive chi-squared term comparing model line luminosities against these fluxes.

  • spectral_indices (SpectralIndexData or None) – Observed spectral indices for fitting.

Returns:

Validated observation container with at least one data modality.

Return type:

Observation

photometry

Photometric filter configuration.

Type:

Photometry or None

spectroscopy

Spectroscopic instrument configuration.

Type:

Spectroscopy or None

noise

Noise model configuration.

Type:

NoiseModel or None

line_fluxes

Observed emission line fluxes.

Type:

LineFluxData or None

spectral_indices

Observed spectral indices for fitting.

Type:

SpectralIndexData or None

Notes

A frozen, immutable dataclass that serves as a declarative container for all observation metadata. Never enters JAX-traced code; used solely for configuration dispatch to precomputation and inference steps. Inspired by Synthesizer’s Instrument pattern, adapted for tengri’s differentiable context.

Examples

Photometry-only:

from tengri import NoiseModel, Observation, Photometry, Spectroscopy, Uniform

obs = Observation(
    photometry=Photometry.from_names(["sdss_r", "sdss_i"]),
)

Joint photometry + spectroscopy (wave_obs is an observed-frame wavelength grid [Angstrom]):

obs = Observation(
    photometry=Photometry.from_names(["jwst_f200w", "jwst_f356w"]),
    spectroscopy=Spectroscopy.nirspec_prism(wave_obs),
    noise=NoiseModel(calibration_floor=Uniform(0.01, 0.15)),
)

property can_do_photometry: bool

Whether photometric filters are configured.

Returns:

True if photometry is configured.

Return type:

bool

Notes

Query method for capability checking. Safe to call even if photometry was not provided to this Observation.

property can_do_spectroscopy: bool

Whether a spectroscopic wavelength grid is configured.

Returns:

True if spectroscopy is configured.

Return type:

bool

Notes

Query method for capability checking. Safe to call even if spectroscopy was not provided to this Observation.

property data_type: str

Inferred data type string (photometry/spectroscopy/joint).

Returns:

One of "photometry", "spectroscopy", or "joint".

Return type:

str

Notes

Returns a string representation of the configured data types, useful for logging and dispatch logic.

get_all_params() dict[str, Distribution][source]

Collect all observation-driven parameters.

Merges calibration polynomial params from spectroscopy config and noise model params from noise config.

Returns:

Parameter name → Distribution mapping. Empty if no observation params are needed (e.g. photometry-only with no noise config).

Return type:

dict

Notes

This method is called by the inference engine to set up the prior structure. Observation parameters include calibration coefficients and noise model hyperparameters, but not SED or SFH params.

property has_line_fluxes: bool

Whether observed emission line fluxes are configured.

Returns:

True if line flux data is configured.

Return type:

bool

Notes

Query method for capability checking. Safe to call even if line fluxes were not provided to this Observation.

property has_spectral_indices: bool

Whether observed spectral indices are configured.

Returns:

True if spectral index data is configured.

Return type:

bool

Notes

Query method for capability checking. Safe to call even if spectral indices were not provided to this Observation.

property is_joint: bool

Whether both photometry and spectroscopy are configured.

Returns:

True if both photometry and spectroscopy are present.

Return type:

bool

Notes

Convenience predicate for detecting joint photometry+spectroscopy fitting (vs. photometry-only or spectroscopy-only).

line_fluxes: LineFluxData | None = None
property n_data: int

Total number of data points.

Returns:

Sum of all photometric, spectroscopic, line flux, and spectral index data points.

Return type:

int

Notes

Aggregates counts across all observation modalities. Useful for data dimensionality checks and prior/posterior shape validation.
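The counting and dispatch rules above (0 for an unconfigured modality, "joint" when both photometry and spectroscopy are present) can be sketched with a small frozen dataclass. ObservationCountsSketch is a hypothetical stand-in that tracks only the size of each modality; it is not tengri's Observation.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ObservationCountsSketch:
    """Hypothetical stand-in for Observation's counting properties.

    Each modality is represented only by its size; None means "not configured".
    """

    n_filters: Optional[int] = None
    n_pixels: Optional[int] = None
    n_lines: Optional[int] = None
    n_indices: Optional[int] = None

    def __post_init__(self):
        # Mirrors the documented rule: photometry or spectroscopy is required.
        if self.n_filters is None and self.n_pixels is None:
            raise ValueError("need photometry and/or spectroscopy")

    @property
    def n_data(self) -> int:
        # Unconfigured modalities contribute 0, like n_data_phot and friends.
        return sum(n or 0 for n in (self.n_filters, self.n_pixels,
                                    self.n_lines, self.n_indices))

    @property
    def data_type(self) -> str:
        if self.n_filters is not None and self.n_pixels is not None:
            return "joint"
        return "photometry" if self.n_filters is not None else "spectroscopy"
```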

property n_data_indices: int

Number of spectral index data points.

Returns:

Number of spectral indices, or 0 if no index data configured.

Return type:

int

Notes

Returns 0 safely if spectral index data is not configured; may be used in conditional logic without prior capability checks.

property n_data_lines: int

Number of emission line flux data points.

Returns:

Number of emission lines, or 0 if no line flux data configured.

Return type:

int

Notes

Returns 0 safely if line flux data is not configured; may be used in conditional logic without prior capability checks.

property n_data_phot: int

Number of photometric data points (filters).

Returns:

Number of filters, or 0 if no photometry configured.

Return type:

int

Notes

Returns 0 safely if photometry is not configured; may be used in conditional logic without prior capability checks.

property n_data_spec: int

Number of spectroscopic data points (pixels).

Returns:

Number of spectral pixels, or 0 if no spectroscopy configured.

Return type:

int

Notes

Returns 0 safely if spectroscopy is not configured; may be used in conditional logic without prior capability checks.

noise: NoiseModel | None = None
observe_photometry(sed_result, z: float, dl_cm: float) Array[source]

Project an observed-frame SED through photometric filters.

Parameters:
  • sed_result (SEDResult) – Observed-frame SED with wavelength and sed.

  • z (float) – Redshift.

  • dl_cm (float) – Luminosity distance [cm].

Returns:

Photometric fluxes in each filter [erg/s/cm²/Hz].

Return type:

jnp.ndarray, shape (n_filters,)

Notes

Requires photometry to be configured. Computes rest-frame wavelengths from the observed-frame input and integrates the SED against each filter transmission curve.
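For illustration, the sketch below shows one common way to project a luminosity density through filter curves: redshift the SED with f_ν = (1+z) L_ν / (4π d_L²), then take a photon-counting (T/λ-weighted) band average. The function name, the rest-frame input convention and the weighting are assumptions made here; tengri's observe_photometry may use different conventions.

```python
import numpy as np


def _trapz(y, x):
    # Trapezoidal rule, written out to avoid NumPy-version differences.
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def band_fluxes_sketch(wave_rest, lnu_rest, z, dl_cm, filter_waves, filter_trans):
    """Hypothetical sketch: band-averaged f_nu [erg/s/cm^2/Hz] per filter.

    wave_rest    : rest-frame wavelength grid [Angstrom], increasing
    lnu_rest     : rest-frame luminosity density L_nu [erg/s/Hz] on that grid
    filter_waves : per-filter observed-frame wavelength arrays [Angstrom]
    filter_trans : per-filter dimensionless transmission arrays
    """
    wave_obs = wave_rest * (1.0 + z)
    # Redshift the SED: f_nu(obs) = (1 + z) L_nu(rest) / (4 pi d_L^2)
    fnu_obs = (1.0 + z) * lnu_rest / (4.0 * np.pi * dl_cm**2)
    fluxes = []
    for fw, ft in zip(filter_waves, filter_trans):
        fnu_on_filter = np.interp(fw, wave_obs, fnu_obs)  # resample onto filter grid
        weight = ft / fw  # photon-counting weight T(lambda) / lambda
        fluxes.append(_trapz(fnu_on_filter * weight, fw) / _trapz(weight, fw))
    return np.array(fluxes)
```

A flat L_ν returns (1+z) L_ν / (4π d_L²) in every band, which is a quick sanity check for any implementation of this projection.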

observe_spectrum(sed_result, z: float, dl_cm: float, sigma_v_kms: float = 0.0) Array[source]

Project an observed-frame SED onto the spectroscopic pixel grid.

Parameters:
  • sed_result (SEDResult) – Observed-frame SED with wavelength and sed.

  • z (float) – Redshift.

  • dl_cm (float) – Luminosity distance [cm].

  • sigma_v_kms (float, optional) – Velocity dispersion [km/s]. Default: 0.0.

Returns:

Spectroscopic flux at each pixel [erg/s/cm²/Hz].

Return type:

jnp.ndarray, shape (n_pixels,)

Notes

Requires spectroscopy to be configured. Applies LSF convolution if a resolution profile is specified. Returns data ready for likelihood evaluation against observed spectra.
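The resolution-dependent broadening can be illustrated with a Gaussian kernel in velocity space. This sketch assumes a constant R, a spectrum sampled uniformly in ln λ, and that the instrumental width (FWHM = c/R), sigma_v_kms and the SSP library dispersion sigma_lib_kms combine in quadrature. The function names and these conventions are illustrative assumptions, not tengri's LSF code (which, per the Spectroscopy parameters, uses a piecewise-constant approximation with lsf_n_bins bins for variable R).

```python
import numpy as np

C_KMS = 299792.458  # speed of light [km/s]
FWHM_TO_SIGMA = 1.0 / (2.0 * np.sqrt(2.0 * np.log(2.0)))  # ~0.4247


def broadening_sigma_kms(R, sigma_v_kms=0.0, sigma_lib_kms=70.0):
    """Hypothetical: extra Gaussian sigma [km/s] to apply to a library spectrum."""
    sigma_inst = C_KMS / R * FWHM_TO_SIGMA  # instrumental sigma from R = lambda / dlambda
    var = sigma_inst**2 + sigma_v_kms**2 - sigma_lib_kms**2
    return float(np.sqrt(max(var, 0.0)))  # a library cannot be "un-broadened"


def convolve_loglam(flux, dlnlam, sigma_kms):
    """Gaussian convolution of a spectrum sampled uniformly in ln(lambda)."""
    sigma_pix = sigma_kms / C_KMS / dlnlam  # d(ln lambda) = dv / c
    if sigma_pix < 1e-3:
        return np.asarray(flux, dtype=float).copy()
    half = int(np.ceil(5 * sigma_pix))
    x = np.arange(-half, half + 1)
    kernel = np.exp(-0.5 * (x / sigma_pix) ** 2)
    kernel /= kernel.sum()  # unit area preserves total flux
    # mode="same" keeps the pixel grid; edge pixels see implicit zero padding.
    return np.convolve(flux, kernel, mode="same")
```

Under these assumptions, R = 2500 gives an instrumental sigma of about 51 km/s, already below the 70 km/s library default, so no extra broadening is applied unless sigma_v_kms makes up the difference.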

pack_data(phot: Array | None = None, spec: Array | None = None) Array[source]

Concatenate photometry and spectroscopy data in canonical order.

Validates array shapes against the observation configuration. Canonical order: [photometry, spectroscopy].

Parameters:
  • phot (array or None) – Photometric data, shape (n_filters,).

  • spec (array or None) – Spectroscopic data, shape (n_pixels,).

Returns:

Packed data array, shape (n_data,).

Return type:

jnp.ndarray

Raises:

ValueError – If array shapes don’t match the observation configuration.

Notes

Both arrays are optional but at least one must be provided and configured in the Observation. Useful for likelihood evaluation and parameter inference pipelines.

photometry: Photometry | None = None
spectral_indices: SpectralIndexData | None = None
spectroscopy: Spectroscopy | None = None
summary() str[source]

Return a human-readable summary of the observation.

Returns:

Multi-line summary string with filter counts, instrument config, noise settings, and total data point count.

Return type:

str

Notes

Used for logging and diagnostics. Does not execute any inference; purely informational output.

unpack_prediction(predicted: Array) dict[str, Array][source]

Split a concatenated prediction into photometry and spectroscopy.

Inverse of pack_data: reverses the concatenation to extract predictions for each observation modality.

Parameters:

predicted (array) – Packed prediction array, shape (n_data,).

Returns:

Keys are "photometry" and/or "spectroscopy", values are the corresponding sub-arrays.

Return type:

dict

Notes

Only keys corresponding to configured observation modalities will be present in the returned dictionary.
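A minimal NumPy sketch of the canonical-order packing and its inverse follows. The function names are hypothetical, and the filter and pixel counts are passed explicitly (0 meaning "not configured") instead of being read from an Observation.

```python
import numpy as np


def pack_data_sketch(n_filters, n_pixels, phot=None, spec=None):
    """Hypothetical sketch of pack_data: validate, then concatenate [phot, spec]."""
    parts = []
    for name, arr, n in (("photometry", phot, n_filters),
                         ("spectroscopy", spec, n_pixels)):
        if arr is None:
            continue
        if n == 0:
            raise ValueError(f"{name} data given but {name} is not configured")
        arr = np.asarray(arr, dtype=float)
        if arr.shape != (n,):
            raise ValueError(f"{name}: expected shape ({n},), got {arr.shape}")
        parts.append(arr)
    if not parts:
        raise ValueError("provide at least one configured data array")
    return np.concatenate(parts)  # canonical order: photometry, then spectroscopy


def unpack_prediction_sketch(predicted, n_filters, n_pixels):
    """Inverse of pack_data_sketch: split using the same canonical order."""
    out = {}
    if n_filters:
        out["photometry"] = predicted[:n_filters]
    if n_pixels:
        out["spectroscopy"] = predicted[n_filters:n_filters + n_pixels]
    return out
```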

Photometry

class tengri.Photometry(filters: tuple[FilterCurve, ...], names: tuple[str, ...] = (), filter_waves: tuple[Array, ...] = (), filter_trans: tuple[Array, ...] = (), n_filters: int = 0)[source]

Bases: object

Photometric observation configuration — filter set, not the fluxes.

This class holds which bands the model should evaluate. Measured fluxes and uncertainties are passed separately to tengri.Fitter (data= and noise=).

Don’t call Photometry(...) directly — use the factory:

>>> phot = tengri.Photometry.from_names(["sdss_g", "sdss_r", "sdss_i"])

or browse the bundled set:

>>> bandset = tengri.list_filters(survey="SDSS").names()
>>> phot = tengri.Photometry.from_names(bandset)

Then the full fit pattern:

obs = tengri.Observation(photometry=phot)
spec = tengri.Parameters(redshift=0.1, …)
model = tengri.SEDModel(spec, ssp_data, observation=obs)
fitter = tengri.Fitter(model,
                       data=measured_fluxes,   # ← your fluxes go here
                       noise=measured_errors)  # ← your sigma here
posterior = fitter.run("nuts")

Parameters:
  • filters (tuple of FilterCurve) – Filter transmission curves.

  • names (tuple of str) – Human-readable filter names (e.g. ("sdss_r", "sdss_i")).

Returns:

Photometry instance with derived fields populated.

Return type:

Photometry

filters

Filter transmission curves.

Type:

tuple[FilterCurve, …]

names

Human-readable filter names.

Type:

tuple[str, …]

filter_waves

Wavelength arrays for each filter [Angstrom].

Type:

tuple[ndarray, …]

filter_trans

Transmission curves for each filter [dimensionless].

Type:

tuple[ndarray, …]

n_filters

Number of filters.

Type:

int

Notes

A frozen dataclass that encapsulates filter metadata and transmission curves. Provides factory methods (from_names, from_filter_set) for convenient construction. Precomputes derived fields (filter_waves, filter_trans, n_filters) at initialization for efficient SED projection.
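The "frozen dataclass with precomputed derived fields" pattern described above can be sketched with the standard library. The FilterCurveSketch layout (a NamedTuple with name, wave and trans fields) and the class name PhotometrySketch are hypothetical; tengri's own FilterCurve may be structured differently.

```python
from dataclasses import dataclass
from typing import NamedTuple

import numpy as np


class FilterCurveSketch(NamedTuple):
    """Hypothetical filter curve: wavelength [Angstrom] and transmission."""
    name: str
    wave: np.ndarray
    trans: np.ndarray


@dataclass(frozen=True)
class PhotometrySketch:
    filters: tuple
    names: tuple = ()
    filter_waves: tuple = ()
    filter_trans: tuple = ()
    n_filters: int = 0

    def __post_init__(self):
        # frozen=True blocks normal assignment, so derived fields are set
        # exactly once here via object.__setattr__.
        object.__setattr__(self, "filters", tuple(self.filters))
        if not self.names:
            object.__setattr__(self, "names", tuple(f.name for f in self.filters))
        object.__setattr__(self, "filter_waves", tuple(f.wave for f in self.filters))
        object.__setattr__(self, "filter_trans", tuple(f.trans for f in self.filters))
        object.__setattr__(self, "n_filters", len(self.filters))
```

Because the instance is frozen, a later assignment such as p.n_filters = 3 raises dataclasses.FrozenInstanceError, which is what makes such objects safe to share as static configuration.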

Examples

>>> from tengri import Photometry
>>> phot = Photometry.from_names(["sdss_r", "sdss_i"])
>>> phot.n_filters
2

filter_trans: tuple[Array, ...] = ()
filter_waves: tuple[Array, ...] = ()
filters: tuple[FilterCurve, ...]
static from_filter_set(filter_set: tuple[list[Array], list[Array], list[FilterCurve]] | list[FilterCurve] | tuple) Photometry[source]

Create Photometry from existing filter data.

Accepts the 3-tuple returned by load_filter_set() or a list of FilterCurve objects.

Parameters:

filter_set (tuple | list) – Either a 3-tuple (filter_waves, filter_trans, filter_curves) from load_filter_set(), or a list/tuple of FilterCurve objects.

Returns:

Configured photometry with precomputed fields (filter curves, wavelengths, transmissions, and filter count).

Return type:

Photometry

Notes

Flexible constructor that accepts pre-loaded filter data. Useful when filters are already loaded or constructed by external code.

Examples

>>> from tengri import Photometry
>>> from tengri.observation.filters import load_filter_set
>>> waves, trans, curves = load_filter_set(["sdss_r", "sdss_i"])
>>> phot = Photometry.from_filter_set((waves, trans, curves))

static from_names(names: Sequence[str], cache_dir: str = 'data/filters') Photometry[source]

Create Photometry from filter registry short names.

Parameters:
  • names (sequence[str]) – Short names from FILTER_REGISTRY (e.g. "sdss_r", "jwst_f200w").

  • cache_dir (str, optional) – Directory for cached SVO filter files. Default: "data/filters".

Returns:

Configured photometry with loaded filter transmission curves. Filters are validated against the global registry.

Return type:

Photometry

Notes

Loads filter transmission curves from the SVO filter service or local cache. Filter names are validated against the registry; unrecognized names raise a KeyError.

Examples

>>> phot = Photometry.from_names(["sdss_u", "sdss_g", "sdss_r"])
>>> phot.n_filters
3

n_filters: int = 0
names: tuple[str, ...] = ()
summary() str[source]

Return a one-line summary of the photometry configuration.

Returns:

Filter count and comma-separated filter names.

Return type:

str

Notes

Provides concise string representation for logging and diagnostics.

Spectroscopy

class tengri.Spectroscopy(wave_obs: Array, resolution: float | Array | None = None, sigma_lib_kms: float = 70.0, lsf_n_bins: int = 16, calibration_order: int = 0, eline_prior_sigma: float = 100.0, eline_mode: str = 'off', eline_catalog: object | None = None, eline_prior_type: str = 'flat', eline_prior_width_dex: float = 0.3, eline_fix_doublets: bool = True, eline_broad: bool = False, eline_broad_fwhm_min_kms: float = 500.0, covariance: Array | None = None)[source]

Bases: object

Spectroscopic observation configuration.

Parameters:
  • wave_obs (jnp.ndarray) – Observed-frame wavelength grid [Angstrom], shape (n_pix,).

  • resolution (float, jnp.ndarray, or None) – Spectral resolution R = lambda / delta_lambda. Scalar for constant R, per-pixel array for wavelength-dependent, or None to skip LSF convolution. Default: None.

  • sigma_lib_kms (float) – SSP library velocity dispersion [km/s] to subtract in quadrature when applying the LSF. Default: 70.0 (MILES).

  • lsf_n_bins (int) – Number of bins for piecewise constant approximation of variable-R LSF convolution. Default: 16.

  • calibration_order (int) – Order of multiplicative Chebyshev calibration polynomial. 0 = no calibration (default). Order N adds N free params (cal_c1, …, cal_cN) with Gaussian(0, 0.1) priors.

  • eline_prior_sigma (float) – Prior width on emission line amplitudes for marginalization. Default: 100.0.

  • eline_mode (str) –

    Emission line fitting mode. One of:

    • "off": No emission line fitting (default).

    • "fixed": Lines from nebular model only.

    • "marginalized": Analytically marginalize line amplitudes (recommended for spectroscopic fitting).

    • "fitted": Line amplitudes as free MCMC parameters.

    Default: "off".

  • eline_catalog (LineList or None) – Line catalog. None falls back to LineList.default_13(); use LineList.default_optical() for FastSpecFit parity. Default: None.

  • eline_prior_type (str) – Prior type for line amplitudes. One of "flat" (uninformative) or "cloudy" (CLOUDY-grid-interpolated). Default: "flat".

  • eline_prior_width_dex (float) – Prior scatter in dex for the "cloudy" prior. Default: 0.3.

  • eline_fix_doublets (bool) – Enforce atomic physics doublet ratios. Default: True.

  • eline_broad (bool) – Enable broad component for AGN candidate lines. Default: False.

  • eline_broad_fwhm_min_kms (float) – Minimum FWHM for the broad component [km/s]. Default: 500.0.

  • covariance (jnp.ndarray or None) – Full spectral covariance matrix, shape (n_pix, n_pix). When provided, the likelihood uses diff @ C^{-1} @ diff instead of per-pixel sum((diff/sigma)^2). The inverse is precomputed at construction time. Default: None (diagonal noise).
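The two chi-squared forms named in the covariance entry can be compared directly. The sketch below (function names hypothetical) precomputes the inverse once, as described above, and uses a diagonal covariance so the result must match the per-pixel sum. As a general numerical remark, not a statement about tengri, a Cholesky solve is usually more stable than an explicit inverse for large or ill-conditioned matrices.

```python
import numpy as np


def chi2_diagonal(data, model, sigma):
    """Per-pixel form: sum((diff / sigma)^2)."""
    diff = data - model
    return float(np.sum((diff / sigma) ** 2))


def chi2_full(data, model, cov_inv):
    """Full-covariance form: diff @ C^{-1} @ diff, with C^{-1} precomputed."""
    diff = data - model
    return float(diff @ cov_inv @ diff)


rng = np.random.default_rng(0)
sigma = rng.uniform(0.5, 2.0, size=50)
data = rng.normal(size=50)
model = np.zeros(50)
cov_inv = np.linalg.inv(np.diag(sigma**2))  # computed once, reused for every evaluation
```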

Returns:

Spectroscopy instance with covariance matrix inverted and metadata set.

Return type:

Spectroscopy

wave_obs

Observed-frame wavelength grid [Angstrom].

Type:

ndarray, shape (n_pix,)

resolution

Spectral resolution.

Type:

float, ndarray, or None

sigma_lib_kms

SSP library velocity dispersion [km/s].

Type:

float

lsf_n_bins

Number of LSF approximation bins.

Type:

int

calibration_order

Chebyshev polynomial order.

Type:

int

eline_prior_sigma

Emission line prior width.

Type:

float

eline_mode

Emission line fitting mode.

Type:

str

eline_catalog

Emission line catalog.

Type:

LineList or None

eline_prior_type

Prior type for line marginalization.

Type:

str

eline_prior_width_dex

Prior scatter [dex].

Type:

float

eline_fix_doublets

Whether to enforce doublet ratios.

Type:

bool

eline_broad

Whether broad AGN component is enabled.

Type:

bool

eline_broad_fwhm_min_kms

Minimum broad component FWHM [km/s].

Type:

float

covariance

Spectral covariance matrix.

Type:

ndarray, shape (n_pix, n_pix) or None

covariance_inv

Inverse covariance matrix (precomputed).

Type:

ndarray, shape (n_pix, n_pix) or None

Notes

A frozen dataclass that encapsulates spectroscopic instrument metadata, including wavelength grid, resolution profile, calibration strategy, and emission-line fitting configuration. Precomputes the inverse covariance matrix at initialization for efficient likelihood evaluation. Used by SEDModel to configure spectral prediction and by the inference engine to set up calibration priors.

Examples

>>> import jax.numpy as jnp
>>> from tengri import Spectroscopy
>>> wave = jnp.linspace(4000.0, 9000.0, 500)
>>> spec = Spectroscopy(wave_obs=wave, resolution=1000.0)
>>> spec.n_pixels
500

calibration_order: int = 0
static constant_r(wave_obs: Array, R: float, **kwargs) Spectroscopy[source]

Constant-resolution spectrograph.

Parameters:
  • wave_obs (ndarray, shape (n_pix,)) – Observed-frame wavelength grid [Angstrom].

  • R (float) – Spectral resolution R = lambda / delta_lambda (constant across all wavelengths, dimensionless).

  • **kwargs – Additional keyword arguments passed to Spectroscopy.

Returns:

Configured spectrograph with constant wavelength-independent resolution.

Return type:

Spectroscopy

Notes

Convenient factory for instruments with wavelength-independent spectral resolution, such as low-resolution JWST NIRSpec PRISM approximations or ideal spectrographs.

property cov_inv: Array | None

Precomputed inverse covariance matrix, or None.

Returns:

Inverse covariance matrix, shape (n_pix, n_pix), or None if diagonal noise is assumed.

Return type:

ndarray or None

Notes

The inverse is precomputed at initialization for efficient likelihood evaluation. This property is read-only.

covariance: Array | None = None
classmethod desi_like(wave_obs: Array, resolution: float = 2500.0, **kwargs) Spectroscopy[source]

DESI-like spectroscopic configuration with full line fitting.

Pre-configured with:

  • ~40-line optical catalog (FastSpecFit parity)

  • Marginalized emission lines with CLOUDY priors

  • Doublet constraints enabled

  • Calibration polynomial order 3

Parameters:
  • wave_obs (ndarray, shape (n_pix,)) – Observed-frame wavelength grid [Angstrom].

  • resolution (float, optional) – Spectral resolution R = lambda / delta_lambda (dimensionless). Default: 2500 (DESI standard).

  • **kwargs – Additional keyword arguments passed to Spectroscopy.

Returns:

Fully configured spectroscopy for DESI-like emission-line fitting with all standard settings pre-applied.

Return type:

Spectroscopy

Notes

This configuration mirrors DESI’s optical spectroscopy capabilities, including marginalized emission line amplitudes with CLOUDY-based priors and atomic physics constraints. Suitable for large spectroscopic surveys of emission-line galaxies.

property effective_catalog: object

Return the catalog, falling back to default_13() if not set.

Returns:

The active line catalog (either explicitly configured or default).

Return type:

LineList

Notes

Provides convenient fallback logic: if eline_catalog is None, automatically returns the default 13-line catalog.

eline_broad: bool = False
eline_broad_fwhm_min_kms: float = 500.0
eline_catalog: object | None = None
eline_fix_doublets: bool = True
eline_mode: str = 'off'
eline_prior_sigma: float = 100.0
eline_prior_type: str = 'flat'
eline_prior_width_dex: float = 0.3
classmethod from_fits(fits_path: str, *, wave_col: str = 'WAVELENGTH', flux_col: str = 'FLUX', err_col: str = 'FLUX_ERROR', wave_unit_aa: float = 1.0, flux_unit_cgs: float = 1.0, ext: int | str = 1, resolution: float | Array | None = None, **kwargs) tuple[Spectroscopy, Array, Array][source]

Load from a generic FITS binary table spectrum.

Parameters:
  • fits_path (str) – Path to the FITS file.

  • wave_col (str) – Column name for wavelength. Default: "WAVELENGTH".

  • flux_col (str) – Column name for flux. Default: "FLUX".

  • err_col (str) – Column name for flux error. Default: "FLUX_ERROR".

  • wave_unit_aa (float) – Multiplicative factor to convert wavelength column to Angstrom. E.g., 1e4 for µm input. Default: 1.0 (already Å).

  • flux_unit_cgs (float) – Multiplicative factor to convert flux column to erg/s/cm²/Hz. E.g., 1e-29 for µJy. Default: 1.0 (already CGS).

  • ext (int or str) – FITS extension. Default: 1.

  • resolution (float, ndarray, or None) – Spectral resolution. Default: None.

  • **kwargs – Additional keyword arguments passed to Spectroscopy.

Returns:

(spec_config, flux_obs, flux_err) in Angstrom and CGS.

Return type:

tuple[Spectroscopy, ndarray, ndarray]

Notes

Requires astropy. Generic reader for any FITS binary table spectrum. For JWST x1d files, prefer from_jwst_x1d which handles unit conversion and resolution auto-detection.
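The two unit factors amount to plain scalings of the columns. A minimal sketch on arrays (no FITS I/O; the function name is hypothetical) also applies the NaN masking documented for from_jwst_x1d, giving bad pixels infinite error. Replacing the NaN flux itself with 0 is an extra choice made here so that residuals divided by the error evaluate to 0 rather than NaN.

```python
import numpy as np


def to_tengri_units_sketch(wave, flux, err, wave_unit_aa=1.0, flux_unit_cgs=1.0):
    """Hypothetical helper: scale columns to Angstrom and erg/s/cm^2/Hz."""
    wave_aa = np.asarray(wave, dtype=float) * wave_unit_aa
    flux_cgs = np.asarray(flux, dtype=float) * flux_unit_cgs
    err_cgs = np.asarray(err, dtype=float) * flux_unit_cgs
    # Mask bad pixels with infinite error (zero weight in chi^2); the flux is
    # zeroed too (an assumption here) so residuals stay finite.
    bad = ~np.isfinite(flux_cgs) | ~np.isfinite(err_cgs)
    flux_cgs = np.where(bad, 0.0, flux_cgs)
    err_cgs = np.where(bad, np.inf, err_cgs)
    return wave_aa, flux_cgs, err_cgs


# JWST-style input: wavelength in micron, flux and error in microJansky
wave, flux, err = to_tengri_units_sketch(
    [1.0, 2.0, 3.0], [10.0, np.nan, 30.0], [1.0, 1.0, 1.0],
    wave_unit_aa=1e4, flux_unit_cgs=1e-29,
)
```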

classmethod from_jwst_x1d(fits_path: str, *, ext: int | str = 1, resolution: float | Array | None = None, **kwargs) tuple[Spectroscopy, Array, Array][source]

Load from a JWST x1d/x1dints FITS spectrum.

Reads WAVELENGTH (µm), FLUX (µJy), and FLUX_ERROR (µJy) from the specified extension and converts to tengri internal units (Angstrom, erg/s/cm²/Hz).

Parameters:
  • fits_path (str) – Path to the x1d FITS file.

  • ext (int or str) – FITS extension containing the spectrum table. Default: 1.

  • resolution (float, ndarray, or None) – Spectral resolution override. If None, auto-selects based on the FILTER or GRATING header keyword. Default: None.

  • **kwargs – Additional keyword arguments passed to Spectroscopy.

Returns:

(spec_config, flux_obs, flux_err) where flux values are in erg/s/cm²/Hz and wavelengths in Angstrom.

Return type:

tuple[Spectroscopy, ndarray, ndarray]

Notes

Requires astropy. NaN pixels are masked by setting their error to infinity.

Unit conversions:

  • Wavelength: µm → Å (×10⁴)

  • Flux: µJy → erg/s/cm²/Hz (×10⁻²⁹)

get_calibration_params() dict[str, Distribution][source]

Return Parameters entries for calibration polynomial.

Returns:

Mapping of calibration coefficient names (cal_c1, …, cal_cN) to Gaussian(0, 0.1) priors. Empty dict if calibration_order == 0.

Return type:

dict[str, Distribution]

Notes

Called by Observation.get_all_params() to register observation-level parameters with the inference engine. Each coefficient has a weak Gaussian prior centered at 0 in log space (unity in linear space).
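One way to read the calibration description is a multiplicative factor exp(Σ c_k T_k(x)), k = 1…N, with wavelength mapped linearly onto x ∈ [-1, 1], so that all-zero coefficients give a factor of exactly 1. That exponential form, the mapping and the function name are assumptions for illustration; tengri's exact parameterization may differ. NumPy's chebyshev.chebval evaluates the series.

```python
import numpy as np
from numpy.polynomial import chebyshev


def calibration_factor_sketch(wave, coeffs):
    """Hypothetical multiplicative calibration for coefficients (cal_c1, ..., cal_cN)."""
    wave = np.asarray(wave, dtype=float)
    # Map the wavelength range linearly onto the Chebyshev domain [-1, 1].
    x = 2.0 * (wave - wave.min()) / (wave.max() - wave.min()) - 1.0
    # Constant term fixed to 0: the overall normalization is left to the SED.
    series = chebyshev.chebval(x, np.concatenate([[0.0], np.asarray(coeffs, dtype=float)]))
    return np.exp(series)
```

With calibration_order = 0 the coefficient list is empty and the factor is 1 everywhere, matching the "no calibration" default.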

property has_calibration: bool

Whether a calibration polynomial is configured.

Returns:

True if calibration_order > 0.

Return type:

bool

Notes

Read-only property; determines whether calibration coefficients are registered as free parameters.

property has_covariance: bool

Whether a full covariance matrix is configured.

Returns:

True if covariance matrix is present.

Return type:

bool

Notes

Read-only property; determined at initialization.

property has_eline_fitting: bool

True if emission lines are being fit (marginalized or fitted mode).

Returns:

True if eline_mode is "marginalized" or "fitted".

Return type:

bool

Notes

Read-only property; it is True in both modes, so it indicates whether line amplitudes enter the fit at all. Check eline_mode to distinguish free MCMC parameters ("fitted") from analytic marginalization ("marginalized").

property has_lsf: bool

Whether LSF convolution is configured.

Returns:

True if resolution profile is specified.

Return type:

bool

Notes

Read-only property; determines whether LSF convolution is applied in the forward model.

lsf_n_bins: int = 16
property n_pixels: int

Number of spectral pixels.

Returns:

Number of wavelength grid points.

Return type:

int

Notes

Read-only property; computed from the length of wave_obs.

static nirspec_g140m(wave_obs: Array, **kwargs) Spectroscopy[source]

JWST NIRSpec G140M: roughly constant R ~ 1000.

Parameters:
  • wave_obs (ndarray, shape (n_pix,)) – Observed-frame wavelength grid [Angstrom].

  • **kwargs – Additional keyword arguments passed to Spectroscopy.

Returns:

Configured NIRSpec G140M spectroscopy with approximately constant resolution R~1000.

Return type:

Spectroscopy

Notes

Configures NIRSpec’s medium-resolution G140M mode with an approximately constant resolution of R~1000 across its wavelength range.

static nirspec_prism(wave_obs: Array, **kwargs) Spectroscopy[source]

JWST NIRSpec PRISM: variable R ~ 30-330 (Jakobsen+2022).

Parameters:
  • wave_obs (ndarray, shape (n_pix,)) – Observed-frame wavelength grid [Angstrom].

  • **kwargs – Additional keyword arguments passed to Spectroscopy.

Returns:

Configured NIRSpec PRISM spectroscopy with wavelength-dependent resolution.

Return type:

Spectroscopy

Notes

Applies wavelength-dependent resolution appropriate for NIRSpec’s PRISM mode. Resolution is lowest (R~30) near 1–1.5 µm and rises toward longer wavelengths, reaching R~330 at the red end near 5 µm (Jakobsen et al. 2022).

resolution: float | Array | None = None
sigma_lib_kms: float = 70.0
summary() str[source]

Return a one-line summary of the spectroscopy configuration.

Returns:

Comma-separated summary (e.g., “500 pixels, R=2500, eline=marginalized”). Includes pixel count, resolution, calibration order, emission-line mode, and covariance info.

Return type:

str

Notes

Used for logging and diagnostics. Provides a compact, human-readable representation of the instrument configuration. Intended for display to users, not for programmatic parsing.

wave_obs: Array

NoiseModel

class tengri.NoiseModel(calibration_floor: float | Distribution = 0.0, student_t_dof: float | None = None)[source]

Bases: object

Noise model configuration.

Parameters:
  • calibration_floor (float or Distribution) – Fractional calibration floor added in quadrature with observational noise: sigma_eff = sqrt(sigma_obs^2 + (f_cal * model)^2). A float value becomes Fixed(value); a Distribution (e.g. Uniform(0.01, 0.15)) makes it a free parameter during inference. Default: 0.0 (no calibration floor).

  • student_t_dof (float or None) – Degrees of freedom for a Student-t likelihood (heavier tails for outlier robustness). None uses a standard Gaussian likelihood. Default: None.
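The effective-sigma formula and the Gaussian/Student-t choice can be sketched as a log-likelihood in NumPy. The function names are hypothetical, and the standard location-scale Student-t density used here is an assumption about how student_t_dof enters tengri's likelihood.

```python
import math

import numpy as np


def effective_sigma(sigma_obs, model, f_cal=0.0):
    """Calibration floor added in quadrature: sqrt(sigma_obs^2 + (f_cal * model)^2)."""
    sigma_obs = np.asarray(sigma_obs, dtype=float)
    model = np.asarray(model, dtype=float)
    return np.sqrt(sigma_obs**2 + (f_cal * model) ** 2)


def log_likelihood_sketch(data, model, sigma_obs, f_cal=0.0, dof=None):
    """Hypothetical sketch: Gaussian log-likelihood, or Student-t if dof is set."""
    data = np.asarray(data, dtype=float)
    model = np.asarray(model, dtype=float)
    sigma = effective_sigma(sigma_obs, model, f_cal)
    r = (data - model) / sigma  # normalized residuals
    if dof is None:
        return float(np.sum(-0.5 * r**2 - np.log(sigma) - 0.5 * np.log(2.0 * np.pi)))
    nu = float(dof)
    log_norm = (math.lgamma(0.5 * (nu + 1.0)) - math.lgamma(0.5 * nu)
                - 0.5 * math.log(nu * math.pi))
    # Heavier tails: large |r| is penalized logarithmically, not quadratically.
    return float(np.sum(log_norm - np.log(sigma)
                        - 0.5 * (nu + 1.0) * np.log1p(r**2 / nu)))
```

For very large dof the Student-t value converges to the Gaussian one, while for small dof a single large outlier costs far less log-likelihood, which is the robustness the parameter is meant to provide.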

Returns:

Noise model instance with configuration validated.

Return type:

NoiseModel

calibration_floor

Fractional calibration uncertainty floor.

Type:

float or Distribution

student_t_dof

Student-t degrees of freedom (or None for Gaussian).

Type:

float or None

Notes

A frozen (immutable) dataclass encapsulating noise model configuration; fields cannot be reassigned after construction. Replaces the older pattern of manually creating noise_frac_cal and noise_dof parameters. Primarily used to register observation-level hyperparameters with the inference engine.

Examples

>>> from tengri import NoiseModel, Uniform
>>> nm = NoiseModel(calibration_floor=Uniform(0.01, 0.1))
>>> list(nm.get_params().keys())
['noise_frac_cal']

calibration_floor: float | Distribution = 0.0
get_params() dict[str, Distribution][source]

Return Parameters entries for the noise model.

Returns:

Mapping of parameter names ("noise_frac_cal", "noise_dof") to Distribution objects. Empty dict if no noise parameters are needed.

Return type:

dict[str, Distribution]

Notes

Called by Observation.get_all_params() to register noise model hyperparameters with the inference engine. Parameters are only included if they are non-trivial (calibration floor > 0 or Student-t dof is not None).

student_t_dof: float | None = None
summary() str[source]

Return a one-line summary of the noise configuration.

Returns:

Single-line summary string with calibration floor and likelihood settings (e.g., "cal floor=0.05 (fixed), Student-t dof=10").

Return type:

str

Notes

Used for logging and diagnostics. Returns "Gaussian (default)" if no custom noise settings are configured.

VIConfig

class tengri.VIConfig(n_samples: int | Callable = 3, n_iterations: int = 10, use_vmap: bool = True, evi_linear_fraction: float = 0.5, draw_linear_kwargs: dict = <factory>, nonlinearly_update_kwargs: dict = <factory>, kl_kwargs: dict = <factory>)[source]

Bases: object

Configuration for geoVI/MGVI/EVI optimize_kl calls.

Parameters:
  • n_samples (int or callable) – Samples per KL iteration. mirror_samples=True (default in NIFTy) doubles this internally, so 3 → 6 effective samples.

  • n_iterations (int) – Number of KL minimization iterations.

  • use_vmap (bool) – Use jax.vmap for residual_map (faster, slightly more memory).

  • evi_linear_fraction (float) – Fraction of iterations using linear_resample before switching to nonlinear_resample in EVI mode.

  • draw_linear_kwargs (dict) – Kwargs for the CG solver that generates each sample.

  • nonlinearly_update_kwargs (dict) – Kwargs for the Newton-CG that inverts the coordinate transform.

  • kl_kwargs (dict) – Kwargs for the outer KL minimization.
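How n_samples, mirror_samples and evi_linear_fraction combine over a run can be illustrated with a small scheduling function. This is a hypothetical sketch, not NIFTy's or tengri's code: it assumes a callable n_samples receives the iteration index and that the linear/nonlinear split is rounded down.

```python
from typing import Callable, List, Tuple, Union


def vi_schedule_sketch(n_iterations: int = 10,
                       n_samples: Union[int, Callable[[int], int]] = 3,
                       evi_linear_fraction: float = 0.5,
                       mirror_samples: bool = True) -> List[Tuple[str, int]]:
    """Hypothetical per-iteration plan for EVI mode: (sample mode, effective samples)."""
    n_linear = int(evi_linear_fraction * n_iterations)  # rounded down (assumption)
    plan = []
    for i in range(n_iterations):
        n = n_samples(i) if callable(n_samples) else n_samples
        mode = "linear_resample" if i < n_linear else "nonlinear_resample"
        # Mirrored sampling uses each draw and its negation, doubling the count.
        plan.append((mode, 2 * n if mirror_samples else n))
    return plan
```

With the defaults (10 iterations, n_samples=3, fraction 0.5) this gives 5 linear_resample iterations followed by 5 nonlinear_resample iterations, each with 6 effective samples.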

n_samples

Samples per KL iteration (doubled by NIFTy’s mirror_samples).

Type:

int or callable

n_iterations

Number of KL minimization iterations.

Type:

int

use_vmap

Whether to use jax.vmap for residual mapping.

Type:

bool

evi_linear_fraction

Fraction of iterations that use linear (MGVI) sampling before switching to nonlinear (geoVI) sampling.

Type:

float

draw_linear_kwargs

Conjugate gradient solver configuration for sampling.

Type:

dict

nonlinearly_update_kwargs

Newton-CG configuration for coordinate transform inversion.

Type:

dict

kl_kwargs

Outer KL minimization configuration.

Type:

dict

Notes

Frozen dataclass configuring the variational inference backend. The VI algorithm itself ('vi' or 'vi_native') is selected separately and is not a VIConfig field. The two backends are NOT posterior-equivalent: 'vi_native' is ~19× faster, but its PSD timescale posteriors differ from the NIFTy path, so validate per problem before swapping.

Examples

>>> from tengri import VIConfig
>>> cfg = VIConfig(n_samples=4, n_iterations=50)
>>> cfg.n_samples
4

draw_linear_kwargs: dict
evi_linear_fraction: float = 0.5
kl_kwargs: dict
n_iterations: int = 10
n_samples: int | Callable = 3
nonlinearly_update_kwargs: dict
use_vmap: bool = True