"""Lazy prediction object for derived physical quantities.
The :class:`Prediction` class provides on-demand computation of derived
quantities from a tengri forward model. Properties are only computed
when first accessed, and intermediate results (SFR, SED, emission lines)
are cached so related quantities share the expensive computation.
Two usage modes
---------------
**Mode 1 — Single-galaxy exploration (lazy):**
.. code-block:: python
pred = model.predict(params)
# SFH quantities — triggers only SFH computation (~100 μs)
pred.sfh.stellar_mass
pred.sfh.mass_weighted_age_gyr
# SED quantities — triggers full SED computation (~300 μs)
pred.sed.l_bol
pred.sed.uv_slope_beta
pred.sed.dn4000
# Emission lines — triggers nebular computation (~200 μs)
pred.lines.halpha
pred.lines.bpt_nii
# Radio and X-ray (empirical scaling relations)
pred.radio.l_1p4ghz
pred.xray.l_x_xrb
**Mode 2 — Population / batch computation (JIT + vmap):**
For computing derived quantities over many parameter sets (posterior
chains, mock catalogs), use the JIT-compatible group methods instead:
.. code-block:: python
import jax
# Batch of 10,000 parameter sets
params_batch = spec.sample_batch(jax.random.PRNGKey(0), n=10_000)
# vmap over SFH quantities — returns SFHQuantities with shape (10000,)
sfh_fn = jax.vmap(model.predict_sfh_quantities)
sfh_batch = sfh_fn(params_batch)
sfh_batch.stellar_mass # shape (10000,)
# vmap over SED quantities — returns SEDQuantities
sed_fn = jax.vmap(model.predict_sed_quantities)
sed_batch = sed_fn(params_batch)
sed_batch.m_uv # shape (10000,)
The lazy :class:`Prediction` object is NOT JIT-compatible (it uses
Python-level caching). For inference loops and batches, always use
the JIT-compatible methods (``predict_sfh_quantities``, etc.).
Caching hierarchy
-----------------
Three computation levels, triggered on demand:
============ ============================ ==========
Level Triggered by Approx cost
============ ============================ ==========
**SFH** Any ``pred.sfh.*`` ~100 μs
**SED** Any ``pred.sed.*`` ~300 μs
**Lines** Any ``pred.lines.*`` ~200 μs
============ ============================ ==========
Each level auto-triggers its dependencies: SED triggers SFH first;
Lines triggers SFH first. Luminosity-weighted quantities in
``pred.sfh`` trigger SED (they need per-bin luminosities).
"""
from typing import NamedTuple
import jax.numpy as jnp
from tengri.components.stellar.sps.dsps_wrapper import (
compute_surviving_mass,
interpolate_mass_remaining,
)
from tengri.utils.sed_quantities import (
KEY_LINES,
compute_balmer_break,
compute_bolometric_luminosity,
compute_dn4000,
compute_fuv_flux,
compute_ionizing_efficiency,
compute_irx,
compute_l_dust_absorbed,
compute_l_radio_1p4ghz_from_sfr,
compute_l_radio_thermal,
compute_l_tir,
compute_l_x_agn,
compute_l_x_xrb,
compute_luminosity_weighted_age,
compute_luminosity_weighted_metallicity,
compute_m_uv,
compute_mass_weighted_age,
compute_mass_weighted_metallicity,
compute_nuv_flux,
compute_q_ir,
compute_rest_uv_color,
compute_uv_luminosity_1600,
compute_uv_slope_beta,
extract_line_luminosity,
)
# ── NamedTuples for JIT-compatible batch computation ──────────────
[docs]
class SFHQuantities(NamedTuple):
    """Derived quantities from the star formation history.

    Every field is a JAX array (a scalar for a single parameter set).
    Because this is a ``NamedTuple`` it is a JAX pytree, so it works with
    ``jax.jit``, ``jax.vmap`` and ``jax.grad``; under ``jax.vmap`` each
    field gains a leading batch axis.

    Instances are returned by :meth:`SEDModel.predict_sfh_quantities`.
    For lazy, single-object access use :attr:`Prediction.sfh` instead.
    That is a separate, non-JIT-compatible accessor that exposes
    properties with the same names; it is not an ``SFHQuantities``.

    Attributes
    ----------
    stellar_mass : jnp.ndarray
        Total formed stellar mass [Msun].
    stellar_mass_surviving : jnp.ndarray
        Surviving mass in living stars + remnants [Msun].
        NaN if the mass-remaining table was not loaded.
    sfr_100myr : jnp.ndarray
        Star formation rate averaged over the last 100 Myr [Msun/yr].
    sfr_10myr : jnp.ndarray
        Star formation rate averaged over the last 10 Myr [Msun/yr].
    ssfr : jnp.ndarray
        Specific star formation rate SFR/M* [1/yr].
    mass_weighted_age_gyr : jnp.ndarray
        Mass-weighted stellar age [Gyr].
    mass_weighted_metallicity : jnp.ndarray
        Mass-weighted metallicity log10(Z), evolving-Z aware.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from tengri import SFHQuantities
    >>> q = SFHQuantities(
    ...     stellar_mass=jnp.array(1e10),
    ...     stellar_mass_surviving=jnp.array(6e9),
    ...     sfr_100myr=jnp.array(5.0),
    ...     sfr_10myr=jnp.array(8.0),
    ...     ssfr=jnp.array(5e-10),
    ...     mass_weighted_age_gyr=jnp.array(3.5),
    ...     mass_weighted_metallicity=jnp.array(-0.5),
    ... )
    >>> float(q.stellar_mass)
    10000000000.0
    >>> "stellar_mass" in q._fields and "sfr_100myr" in q._fields
    True
    """

    # Field order is part of the public interface (positional construction
    # and pytree flattening); do not reorder.
    stellar_mass: jnp.ndarray
    stellar_mass_surviving: jnp.ndarray
    sfr_100myr: jnp.ndarray
    sfr_10myr: jnp.ndarray
    ssfr: jnp.ndarray
    mass_weighted_age_gyr: jnp.ndarray
    mass_weighted_metallicity: jnp.ndarray
[docs]
class SEDQuantities(NamedTuple):
    """Derived quantities from the spectral energy distribution.

    All fields are JAX arrays. Being a ``NamedTuple``, this is a JAX
    pytree and works with ``jax.jit`` and ``jax.vmap``. Instances are
    returned by :meth:`SEDModel.predict_sed_quantities`.

    Attributes
    ----------
    l_bol : jnp.ndarray
        Bolometric luminosity [Lsun].
    l_tir : jnp.ndarray
        Total infrared luminosity 8–1000 μm [Lsun].
    l_dust_absorbed : jnp.ndarray
        Dust-absorbed luminosity [Lsun]. NaN if no intrinsic SED.
    irx : jnp.ndarray
        Infrared excess log10(L_TIR / νLν_1600) [dimensionless].
    uv_slope_beta : jnp.ndarray
        UV spectral slope β in range 1250–2600 Å [dimensionless].
    dn4000 : jnp.ndarray
        Narrow 4000 Å break, Balogh et al. 1999 [dimensionless].
    balmer_break : jnp.ndarray
        Modified Balmer break, Wang et al. 2024 [dimensionless].
    m_uv : jnp.ndarray
        Absolute UV magnitude at rest-frame 1500 Å [AB].
    fuv_flux : jnp.ndarray
        Mean flux density in FUV 1000–1700 Å [erg/s/Hz].
    nuv_flux : jnp.ndarray
        Mean flux density in NUV 1700–3200 Å [erg/s/Hz].
    fuv_flux_intrinsic : jnp.ndarray
        Dust-free FUV flux [erg/s/Hz]. NaN if no intrinsic SED.
    nuv_flux_intrinsic : jnp.ndarray
        Dust-free NUV flux [erg/s/Hz]. NaN if no intrinsic SED.
    rest_uv_color : jnp.ndarray
        Rest-frame U-V color [AB magnitudes].
    luminosity_weighted_age_gyr : jnp.ndarray
        Luminosity-weighted age [Gyr].
    luminosity_weighted_metallicity : jnp.ndarray
        Luminosity-weighted metallicity log10(Z).

    Notes
    -----
    :attr:`Prediction.sed` does *not* return this class. It returns a
    lazy, non-JIT-compatible accessor that exposes properties with the
    same names.

    Examples
    --------
    Batched, JIT-compatible access:

    .. code-block:: python

        import jax

        sed_batch = jax.vmap(model.predict_sed_quantities)(params_batch)
        sed_batch.m_uv  # one value per parameter set

    Lazy single-object access uses the same field names:

    .. code-block:: python

        pred = model.predict(params)
        sed = pred.sed  # lazy accessor, same field names as SEDQuantities
        print(float(sed.l_bol))          # bolometric luminosity [Lsun]
        print(float(sed.dn4000))         # 4000 Å break strength
        print(float(sed.uv_slope_beta))  # UV slope beta
    """

    # Field order is part of the public interface (positional construction
    # and pytree flattening); do not reorder.
    l_bol: jnp.ndarray
    l_tir: jnp.ndarray
    l_dust_absorbed: jnp.ndarray
    irx: jnp.ndarray
    uv_slope_beta: jnp.ndarray
    dn4000: jnp.ndarray
    balmer_break: jnp.ndarray
    m_uv: jnp.ndarray
    fuv_flux: jnp.ndarray
    nuv_flux: jnp.ndarray
    fuv_flux_intrinsic: jnp.ndarray
    nuv_flux_intrinsic: jnp.ndarray
    rest_uv_color: jnp.ndarray
    luminosity_weighted_age_gyr: jnp.ndarray
    luminosity_weighted_metallicity: jnp.ndarray
[docs]
class EmissionLines(NamedTuple):
    """Key emission line luminosities.

    NaN for all fields when no nebular model is active. For doublets
    ([OII], C IV), the luminosities of both components are summed.
    Being a ``NamedTuple``, this is a JAX pytree compatible with
    ``jax.jit`` and ``jax.vmap``. Instances are returned by
    :meth:`SEDModel.predict_emission_lines`.

    Attributes
    ----------
    lya : jnp.ndarray
        Lyman-alpha at 1216 Å [Lsun].
    civ_1549 : jnp.ndarray
        C IV doublet 1548+1551 Å, summed [Lsun].
    oii : jnp.ndarray
        [OII] doublet 3726+3729 Å, summed [Lsun].
    hbeta : jnp.ndarray
        H-beta at 4861 Å [Lsun].
    oiii_4959 : jnp.ndarray
        [OIII] at 4959 Å [Lsun].
    oiii_5007 : jnp.ndarray
        [OIII] at 5007 Å [Lsun].
    nii_6548 : jnp.ndarray
        [NII] at 6548 Å [Lsun].
    halpha : jnp.ndarray
        H-alpha at 6563 Å [Lsun].
    nii_6584 : jnp.ndarray
        [NII] at 6584 Å [Lsun].
    sii_6717 : jnp.ndarray
        [SII] at 6717 Å [Lsun].
    sii_6731 : jnp.ndarray
        [SII] at 6731 Å [Lsun].

    Notes
    -----
    :attr:`Prediction.lines` does *not* return this class. It returns a
    lazy, non-JIT-compatible accessor that exposes properties with the
    same line names, plus diagnostic ratios.

    Examples
    --------
    Lazy single-object access uses the same field names:

    .. code-block:: python

        pred = model.predict(params)
        lines = pred.lines  # lazy accessor, same field names as EmissionLines
        print(float(lines.halpha))     # H-alpha luminosity [Lsun]
        print(float(lines.oiii_5007))  # [OIII] 5007 Å luminosity [Lsun]

        # BPT diagram coordinates (linear ratios)
        bpt_x = float(lines.nii_6584 / lines.halpha)
        bpt_y = float(lines.oiii_5007 / lines.hbeta)
    """

    # Field order is part of the public interface (positional construction
    # and pytree flattening); do not reorder.
    lya: jnp.ndarray
    civ_1549: jnp.ndarray
    oii: jnp.ndarray
    hbeta: jnp.ndarray
    oiii_4959: jnp.ndarray
    oiii_5007: jnp.ndarray
    nii_6548: jnp.ndarray
    halpha: jnp.ndarray
    nii_6584: jnp.ndarray
    sii_6717: jnp.ndarray
    sii_6731: jnp.ndarray
[docs]
class DerivedQuantities(NamedTuple):
    """Convenience bundle holding every derived-quantity group.

    Produced by :meth:`SEDModel.predict_derived`. It nests an
    :class:`SFHQuantities` and an :class:`SEDQuantities`. All three are
    ``NamedTuple`` types, so the whole object is a JAX pytree usable with
    ``jax.jit`` and ``jax.vmap``.

    Attributes
    ----------
    sfh : SFHQuantities
        Quantities derived from the star formation history.
    sed : SEDQuantities
        Quantities derived from the spectral energy distribution.

    Examples
    --------
    .. code-block:: python

        derived = model.predict_derived(params)
        print(float(derived.sfh.stellar_mass))   # [Msun]
        print(float(derived.sed.dn4000))         # 4000 Å break
        print(float(derived.sed.uv_slope_beta))  # UV slope β
    """

    # Order matters for positional construction and pytree flattening.
    sfh: SFHQuantities
    sed: SEDQuantities
# ── Lazy property group base ──────────────────────────────────────
class _CachedBase:
"""Base class for lazy-cached prediction property groups.
Each subclass accesses the parent :class:`Prediction` object's
shared cache via ``self._pred._cache`` and triggers the appropriate
computation level via ``self._pred._ensure_*()`` methods.
"""
__slots__ = ("_pred",)
def __init__(self, prediction):
self._pred = prediction
# ── SFH properties (lazy) ─────────────────────────────────────────
class SFHProperties(_CachedBase):
    """Lazy property accessor for SFH-derived quantities.

    Reading any property calls the parent's ``_ensure_sfh()``. The SFH
    stage (SFR, CSP weights, parameter dict) is therefore computed once and
    shared through ``self._pred._cache``. The two luminosity-weighted
    properties also call ``_ensure_sed()``, because they need the per-SSP
    fluxes cached by the SED stage.

    Instances are returned by :attr:`Prediction.sfh`.

    Attributes
    ----------
    stellar_mass : property
        Total formed stellar mass [Msun].
    stellar_mass_surviving : property
        Surviving stellar mass [Msun].
    sfr_100myr : property
        SFR averaged over 100 Myr [Msun/yr].
    sfr_10myr : property
        SFR averaged over 10 Myr [Msun/yr].
    ssfr : property
        Specific SFR [1/yr].
    mass_weighted_age_gyr : property
        Mass-weighted age [Gyr].
    mass_weighted_metallicity : property
        Mass-weighted metallicity log10(Z).
    luminosity_weighted_age_gyr : property
        Luminosity-weighted age [Gyr].
    luminosity_weighted_metallicity : property
        Luminosity-weighted metallicity log10(Z).

    Notes
    -----
    These are plain Python properties reading a Python-level cache.
    **JIT-compatible: no.** Use them in post-processing, never inside
    :func:`jax.jit`. For batched or compiled evaluation, use
    :meth:`SEDModel.predict_sfh_quantities` instead.

    Examples
    --------
    .. code-block:: python

        pred = model.predict(params)
        pred.sfh.stellar_mass           # triggers SFH computation
        pred.sfh.mass_weighted_age_gyr  # reuses the cached weights
    """

    def _mean_sfr_within(self, max_age_yr):
        """Average the cached SFR over age bins no older than ``max_age_yr``.

        Parameters
        ----------
        max_age_yr : float
            Upper edge of the averaging window [yr].

        Returns
        -------
        jnp.ndarray
            Unweighted mean of the cached ``sfr`` over bins with
            ``model.age_yr <= max_age_yr`` [Msun/yr]. If no bin lies in
            the window, ``sfr[0]`` is returned instead.

        Notes
        -----
        NOTE(review): this is a plain mean over bins. It equals a true
        time average only if the bins are evenly spaced in age; confirm
        against the model's age grid.
        """
        self._pred._ensure_sfh()
        sfr = self._pred._cache["sfr"]
        in_window = self._pred._model.age_yr <= max_age_yr
        n_in_window = jnp.sum(in_window)
        # jnp.where evaluates both branches. Flooring the divisor at 1 keeps
        # the unselected branch finite when the window is empty.
        return jnp.where(
            n_in_window > 0,
            jnp.sum(sfr * in_window) / jnp.maximum(n_in_window, 1.0),
            sfr[0],
        )

    @property
    def stellar_mass(self):
        """Total formed stellar mass.

        Returns
        -------
        jnp.ndarray
            Sum of the cached SFH weights: total stellar mass ever
            formed [Msun].
        """
        self._pred._ensure_sfh()
        return jnp.sum(self._pred._cache["weights"])

    @property
    def stellar_mass_surviving(self):
        """Surviving stellar mass in living stars and remnants.

        Returns
        -------
        jnp.ndarray
            Surviving stellar mass [Msun]. NaN when the SSP data has no
            mass-remaining table (``ssp_mass_remaining is None``).

        Notes
        -----
        The mass-remaining table is interpolated at the single metallicity
        ``p["log_z_abs"]`` (default 0.0).

        NOTE(review): unlike :attr:`mass_weighted_metallicity`, the
        evolving-metallicity parameters ``log_z_abs_initial`` and
        ``log_z_abs_final`` are not used here. Confirm this is intended.
        """
        self._pred._ensure_sfh()
        model = self._pred._model
        mass_remaining = model.ssp_data.ssp_mass_remaining
        if mass_remaining is None:
            return jnp.array(jnp.nan)
        log_z = self._pred._cache["p"].get("log_z_abs", 0.0)
        mr_at_met = interpolate_mass_remaining(
            mass_remaining, model.ssp_data.ssp_lgmet, log_z
        )
        return compute_surviving_mass(self._pred._cache["weights"], mr_at_met)

    @property
    def sfr_100myr(self):
        """Star formation rate averaged over the last 100 Myr.

        Returns
        -------
        jnp.ndarray
            Mean SFR over age bins with ``age_yr <= 1e8`` [Msun/yr]. Falls
            back to the first SFR bin if no bin lies in that window.
        """
        return self._mean_sfr_within(1e8)

    @property
    def sfr_10myr(self):
        """Star formation rate averaged over the last 10 Myr.

        Returns
        -------
        jnp.ndarray
            Mean SFR over age bins with ``age_yr <= 1e7`` [Msun/yr]. Falls
            back to the first SFR bin if no bin lies in that window.
        """
        return self._mean_sfr_within(1e7)

    @property
    def ssfr(self):
        """Specific star formation rate.

        Divides :attr:`sfr_100myr` by the surviving mass when it is
        available (not NaN), otherwise by the formed mass. The mass is
        floored at 1 Msun so the division is always defined.

        Returns
        -------
        jnp.ndarray
            Specific star formation rate SFR/M* [1/yr].
        """
        mass_surv = self.stellar_mass_surviving
        mass = jnp.where(jnp.isnan(mass_surv), self.stellar_mass, mass_surv)
        return self.sfr_100myr / jnp.maximum(mass, 1.0)

    @property
    def mass_weighted_age_gyr(self):
        """Mass-weighted stellar age.

        Returns
        -------
        jnp.ndarray
            Age weighted by stellar mass [Gyr].
        """
        self._pred._ensure_sfh()
        return compute_mass_weighted_age(
            self._pred._cache["weights"], self._pred._model.ssp_ages_yr
        )

    @property
    def mass_weighted_metallicity(self):
        """Mass-weighted metallicity.

        For single-metallicity models this is the metallicity parameter.
        For evolving metallicity it is Σ(w·Z)/Σ(w).

        Returns
        -------
        jnp.ndarray
            Metallicity weighted by stellar mass, log10(Z).
        """
        self._pred._ensure_sfh()
        p = self._pred._cache["p"]
        return compute_mass_weighted_metallicity(
            self._pred._cache["weights"],
            self._pred._model.ssp_ages_yr,
            p.get("log_z_abs", 0.0),
            log_z_initial=p.get("log_z_abs_initial"),
            log_z_final=p.get("log_z_abs_final"),
        )

    @property
    def luminosity_weighted_age_gyr(self):
        """Luminosity-weighted stellar age (triggers SED computation).

        Returns
        -------
        jnp.ndarray
            Age weighted by stellar luminosity [Gyr].
        """
        self._pred._ensure_sed()
        return compute_luminosity_weighted_age(
            self._pred._cache["weights"],
            self._pred._cache["ssp_flux_at_z"],
            self._pred._model.ssp_ages_yr,
            self._pred._model.ssp_data.ssp_wave,
        )

    @property
    def luminosity_weighted_metallicity(self):
        """Luminosity-weighted metallicity (triggers SED computation).

        Returns
        -------
        jnp.ndarray
            Metallicity weighted by stellar luminosity, log10(Z).
        """
        self._pred._ensure_sed()
        p = self._pred._cache["p"]
        return compute_luminosity_weighted_metallicity(
            self._pred._cache["weights"],
            self._pred._cache["ssp_flux_at_z"],
            self._pred._model.ssp_ages_yr,
            self._pred._model.ssp_data.ssp_wave,
            p.get("log_z_abs", 0.0),
            log_z_initial=p.get("log_z_abs_initial"),
            log_z_final=p.get("log_z_abs_final"),
        )
# ── SED properties (lazy) ─────────────────────────────────────────
class SEDProperties(_CachedBase):
    """Lazy property accessor for SED-derived quantities.

    Reading any property calls the parent's ``_ensure_sed()``. The full SED
    computation (dust attenuation, emission, AGN, etc.) therefore runs once.
    Its cached results are the total SED, the intrinsic SED when one was
    produced, and the per-SSP fluxes, and all properties share them.

    Instances are returned by :attr:`Prediction.sed`.

    Attributes
    ----------
    l_bol : property
        Bolometric luminosity [Lsun].
    l_tir : property
        Total infrared luminosity [Lsun].
    l_dust_absorbed : property
        Dust-absorbed luminosity [Lsun].
    irx : property
        Infrared excess [dimensionless].
    uv_slope_beta : property
        UV spectral slope [dimensionless].
    dn4000 : property
        4000 Å break [dimensionless].
    balmer_break : property
        Balmer break [dimensionless].
    m_uv : property
        Absolute UV magnitude [AB].
    fuv_flux : property
        FUV flux density [erg/s/Hz].
    nuv_flux : property
        NUV flux density [erg/s/Hz].
    fuv_flux_intrinsic : property
        Dust-free FUV flux [erg/s/Hz].
    nuv_flux_intrinsic : property
        Dust-free NUV flux [erg/s/Hz].
    rest_uv_color : property
        Rest-frame U-V color [AB magnitudes].
    luminosity_weighted_age_gyr : property
        Luminosity-weighted age [Gyr].
    luminosity_weighted_metallicity : property
        Luminosity-weighted metallicity log10(Z).

    Notes
    -----
    These are plain Python properties reading a Python-level cache.
    **JIT-compatible: no.** Use them in post-processing, never inside
    :func:`jax.jit`. For batched or compiled evaluation, use
    :meth:`SEDModel.predict_sed_quantities` instead.

    Examples
    --------
    .. code-block:: python

        pred = model.predict(params)
        pred.sed.l_bol          # triggers SED computation
        pred.sed.uv_slope_beta  # reuses the cached SED
    """

    def _wave(self):
        """Rest-frame wavelength array of the SSP library."""
        return self._pred._model.ssp_data.ssp_wave

    def _sed(self):
        """Cached total SED, computing it first if necessary."""
        self._pred._ensure_sed()
        return self._pred._cache["sed_total"]

    def _sed_intrinsic(self):
        """Cached intrinsic (unattenuated) SED, or ``None`` if none was produced."""
        self._pred._ensure_sed()
        return self._pred._cache.get("sed_intrinsic")

    def _from_intrinsic(self, compute):
        """Return ``compute(sed_intrinsic, wave)``, or NaN if there is no intrinsic SED."""
        sed_intr = self._sed_intrinsic()
        if sed_intr is None:
            return jnp.array(jnp.nan)
        return compute(sed_intr, self._wave())

    @property
    def l_bol(self):
        """Bolometric luminosity.

        Returns
        -------
        jnp.ndarray
            Total bolometric luminosity [Lsun].
        """
        return compute_bolometric_luminosity(self._sed(), self._wave())

    @property
    def l_tir(self):
        """Total infrared luminosity.

        Returns
        -------
        jnp.ndarray
            Integrated luminosity in the 8–1000 μm range [Lsun].
        """
        return compute_l_tir(self._sed(), self._wave())

    @property
    def l_dust_absorbed(self):
        """Dust-absorbed luminosity.

        Returns
        -------
        jnp.ndarray
            Luminosity absorbed by dust [Lsun], computed from the intrinsic
            and attenuated SEDs. NaN if no intrinsic SED is available.
        """
        sed_intr = self._sed_intrinsic()
        if sed_intr is None:
            return jnp.array(jnp.nan)
        # _sed_intrinsic() has already run _ensure_sed(); a second call is
        # unnecessary before reading the attenuated SED from the cache.
        sed_att = self._pred._cache["sed_attenuated"]
        return compute_l_dust_absorbed(sed_intr, sed_att, self._wave())

    @property
    def irx(self):
        """Infrared excess.

        Returns
        -------
        jnp.ndarray
            IRX = log10(L_TIR / νLν_1600) [dimensionless].
        """
        l_tir = self.l_tir
        l_uv = compute_uv_luminosity_1600(self._sed(), self._wave())
        return compute_irx(l_tir, l_uv)

    @property
    def uv_slope_beta(self):
        """UV spectral slope.

        Returns
        -------
        jnp.ndarray
            Spectral slope β in the rest-frame range 1250–2600 Å [dimensionless].
        """
        return compute_uv_slope_beta(self._sed(), self._wave())

    @property
    def dn4000(self):
        """Narrow 4000 Å break.

        Returns
        -------
        jnp.ndarray
            Narrow D_n(4000) break from Balogh et al. 1999 [dimensionless].
        """
        return compute_dn4000(self._sed(), self._wave())

    @property
    def balmer_break(self):
        """Modified Balmer break.

        Returns
        -------
        jnp.ndarray
            Modified Balmer break from Wang et al. 2024 [dimensionless].
        """
        return compute_balmer_break(self._sed(), self._wave())

    @property
    def m_uv(self):
        """Absolute UV magnitude.

        Returns
        -------
        jnp.ndarray
            Absolute magnitude at rest-frame 1500 Å [AB].
        """
        return compute_m_uv(self._sed(), self._wave())

    @property
    def fuv_flux(self):
        """Mean flux density in the FUV.

        Returns
        -------
        jnp.ndarray
            Mean flux density in the FUV 1000–1700 Å range [erg/s/Hz].
        """
        return compute_fuv_flux(self._sed(), self._wave())

    @property
    def nuv_flux(self):
        """Mean flux density in the NUV.

        Returns
        -------
        jnp.ndarray
            Mean flux density in the NUV 1700–3200 Å range [erg/s/Hz].
        """
        return compute_nuv_flux(self._sed(), self._wave())

    @property
    def fuv_flux_intrinsic(self):
        """Dust-free FUV flux.

        Returns
        -------
        jnp.ndarray
            Dust-free flux density in the FUV 1000–1700 Å range [erg/s/Hz],
            or NaN if no intrinsic SED is available.
        """
        return self._from_intrinsic(compute_fuv_flux)

    @property
    def nuv_flux_intrinsic(self):
        """Dust-free NUV flux.

        Returns
        -------
        jnp.ndarray
            Dust-free flux density in the NUV 1700–3200 Å range [erg/s/Hz],
            or NaN if no intrinsic SED is available.
        """
        return self._from_intrinsic(compute_nuv_flux)

    @property
    def rest_uv_color(self):
        """Rest-frame U-V color.

        Returns
        -------
        jnp.ndarray
            U-V color in rest frame [AB magnitudes].
        """
        return compute_rest_uv_color(self._sed(), self._wave())

    @property
    def luminosity_weighted_age_gyr(self):
        """Luminosity-weighted stellar age.

        Returns
        -------
        jnp.ndarray
            Age weighted by stellar luminosity [Gyr].
        """
        self._pred._ensure_sed()
        return compute_luminosity_weighted_age(
            self._pred._cache["weights"],
            self._pred._cache["ssp_flux_at_z"],
            self._pred._model.ssp_ages_yr,
            self._pred._model.ssp_data.ssp_wave,
        )

    @property
    def luminosity_weighted_metallicity(self):
        """Luminosity-weighted metallicity.

        Returns
        -------
        jnp.ndarray
            Metallicity weighted by stellar luminosity, log10(Z).
        """
        self._pred._ensure_sed()
        p = self._pred._cache["p"]
        return compute_luminosity_weighted_metallicity(
            self._pred._cache["weights"],
            self._pred._cache["ssp_flux_at_z"],
            self._pred._model.ssp_ages_yr,
            self._pred._model.ssp_data.ssp_wave,
            p.get("log_z_abs", 0.0),
            log_z_initial=p.get("log_z_abs_initial"),
            log_z_final=p.get("log_z_abs_final"),
        )
# ── Emission line properties (lazy) ───────────────────────────────
class LineProperties(_CachedBase):
    """Lazy property accessor for emission line luminosities and diagnostic ratios.

    Reading any line property calls the parent's ``_ensure_lines()``. The
    nebular line grid (wavelengths and luminosities) is therefore computed
    once and cached. Each line is then extracted from that grid using its
    ``KEY_LINES`` entry. Line luminosities are in Lsun.

    All diagnostic ratios except :attr:`balmer_decrement` are log10
    values computed by :meth:`_log_ratio`. The numerator and denominator
    are each floored at ``_RATIO_FLOOR`` (1e-50) before dividing, so
    zero-valued lines never produce ``log10(0)``. As a consequence, two
    exactly-zero lines give a ratio of 0.0. ``jnp.maximum`` propagates
    NaN, so when no nebular model is active (all lines NaN) every ratio
    is NaN as well.

    Instances are returned by :attr:`Prediction.lines`.

    Attributes
    ----------
    lya : property
        Lyman-alpha [Lsun].
    civ_1549 : property
        C IV doublet [Lsun].
    oii : property
        [OII] doublet [Lsun].
    hbeta : property
        H-beta [Lsun].
    oiii_4959 : property
        [OIII] 4959 [Lsun].
    oiii_5007 : property
        [OIII] 5007 [Lsun].
    nii_6548 : property
        [NII] 6548 [Lsun].
    halpha : property
        H-alpha [Lsun].
    nii_6584 : property
        [NII] 6584 [Lsun].
    sii_6717 : property
        [SII] 6717 [Lsun].
    sii_6731 : property
        [SII] 6731 [Lsun].
    bpt_nii : property
        BPT [NII] diagnostic [dimensionless].
    bpt_sii : property
        BPT [SII] diagnostic [dimensionless].
    o3hb : property
        [OIII]/Hβ diagnostic [dimensionless].
    r23 : property
        R23 metallicity diagnostic [dimensionless].
    o32 : property
        O32 ionization parameter [dimensionless].
    balmer_decrement : property
        Hα/Hβ ratio [dimensionless].

    Notes
    -----
    These are plain Python properties reading a Python-level cache.
    **JIT-compatible: no.** Use them in post-processing, never inside
    :func:`jax.jit`. For batched or compiled evaluation, use
    :meth:`SEDModel.predict_emission_lines` instead.

    Examples
    --------
    .. code-block:: python

        pred = model.predict(params)
        pred.lines.halpha   # Hα luminosity in Lsun
        pred.lines.bpt_nii  # log10([NII]6584 / Hα)
    """

    # Floor applied to both sides of every diagnostic ratio so that zero
    # luminosities never yield division by zero or log10(0).
    _RATIO_FLOOR = 1e-50

    def _get_line(self, name):
        """Extract one line luminosity from the cached grid, computing it if necessary."""
        self._pred._ensure_lines()
        lw = self._pred._cache["line_waves"]
        ll = self._pred._cache["line_lums"]
        return extract_line_luminosity(lw, ll, KEY_LINES[name])

    def _log_ratio(self, numerator, denominator):
        """Return ``log10(numerator / denominator)`` with both floored at ``_RATIO_FLOOR``."""
        floor = self._RATIO_FLOOR
        return jnp.log10(jnp.maximum(numerator, floor) / jnp.maximum(denominator, floor))

    # --- Individual lines ---

    @property
    def lya(self):
        """Lyman-alpha at 1216 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("lya")

    @property
    def civ_1549(self):
        """C IV doublet at 1548+1551 Å.

        Returns
        -------
        jnp.ndarray
            Summed line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("civ_1549")

    @property
    def oii(self):
        """[OII] doublet at 3726+3729 Å.

        Returns
        -------
        jnp.ndarray
            Summed line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("oii")

    @property
    def hbeta(self):
        """H-beta at 4861 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("hbeta")

    @property
    def oiii_4959(self):
        """[OIII] at 4959 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("oiii_4959")

    @property
    def oiii_5007(self):
        """[OIII] at 5007 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("oiii_5007")

    @property
    def nii_6548(self):
        """[NII] at 6548 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("nii_6548")

    @property
    def halpha(self):
        """H-alpha at 6563 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("halpha")

    @property
    def nii_6584(self):
        """[NII] at 6584 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("nii_6584")

    @property
    def sii_6717(self):
        """[SII] at 6717 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("sii_6717")

    @property
    def sii_6731(self):
        """[SII] at 6731 Å.

        Returns
        -------
        jnp.ndarray
            Line luminosity [Lsun], or NaN if no nebular model.
        """
        return self._get_line("sii_6731")

    # --- Diagnostic ratios ---

    @property
    def bpt_nii(self):
        """BPT-NII diagnostic ratio.

        Returns
        -------
        jnp.ndarray
            log10([NII]6584 / Hα), the BPT diagram x-axis [dimensionless].
        """
        return self._log_ratio(self.nii_6584, self.halpha)

    @property
    def bpt_sii(self):
        """BPT-SII diagnostic ratio.

        Returns
        -------
        jnp.ndarray
            log10(([SII]6717 + [SII]6731) / Hα) [dimensionless].
        """
        sii_total = self.sii_6717 + self.sii_6731
        return self._log_ratio(sii_total, self.halpha)

    @property
    def o3hb(self):
        """[OIII]5007/Hβ diagnostic ratio.

        Returns
        -------
        jnp.ndarray
            log10([OIII]5007 / Hβ), the BPT diagram y-axis [dimensionless].
        """
        return self._log_ratio(self.oiii_5007, self.hbeta)

    @property
    def r23(self):
        """R23 metallicity diagnostic.

        Returns
        -------
        jnp.ndarray
            log10(([OII] + [OIII]4959 + [OIII]5007) / Hβ) [dimensionless].
        """
        numerator = self.oii + self.oiii_4959 + self.oiii_5007
        return self._log_ratio(numerator, self.hbeta)

    @property
    def o32(self):
        """O32 ionization-parameter diagnostic.

        Returns
        -------
        jnp.ndarray
            log10([OIII]5007 / [OII]) [dimensionless].
        """
        return self._log_ratio(self.oiii_5007, self.oii)

    @property
    def balmer_decrement(self):
        """Balmer decrement (linear, not log).

        Only the denominator is floored, so a zero Hα gives 0.0.

        Returns
        -------
        jnp.ndarray
            Hα/Hβ intensity ratio [dimensionless]. Case B intrinsic value: 2.86.
        """
        return self.halpha / jnp.maximum(self.hbeta, self._RATIO_FLOOR)
# ── Radio properties (lazy) ───────────────────────────────────────
class RadioProperties(_CachedBase):
    """Radio quantities derived from empirical scaling relations.

    Reached as :attr:`Prediction.radio`. The total 1.4 GHz luminosity is
    obtained from the 100 Myr SFR via the FIR–radio correlation
    (Bell 2003; Murphy+2011); the thermal (free-free) part is obtained
    from the ionizing photon budget; the non-thermal part is their
    difference.

    Attributes
    ----------
    l_1p4ghz : property
        Radio luminosity at 1.4 GHz [erg/s/Hz].
    l_thermal : property
        Thermal free-free luminosity [erg/s/Hz].
    l_nonthermal : property
        Non-thermal synchrotron luminosity [erg/s/Hz].
    q_ir : property
        FIR-radio correlation parameter [dimensionless].

    Notes
    -----
    This is a lazy accessor bound to a parent :class:`Prediction`, not an
    array container. Each property is evaluated on access from
    intermediates (SFH, SED, ionizing budget) that the parent caches in
    Python, which is why it is not JIT-compatible.

    Examples
    --------
    >>> pred = model.predict(params)
    >>> pred.radio.l_1p4ghz  # 1.4 GHz luminosity
    Array(5.2e28, dtype=float64)
    """

    @property
    def l_1p4ghz(self):
        """Total radio luminosity density at 1.4 GHz.

        Returns
        -------
        float
            L_1.4GHz [erg/s/Hz], scaled from ``pred.sfh.sfr_100myr``.

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.
        """
        return compute_l_radio_1p4ghz_from_sfr(self._pred.sfh.sfr_100myr)

    @property
    def l_thermal(self):
        """Thermal (free-free) radio luminosity density at 1.4 GHz.

        Returns
        -------
        float
            Free-free L_1.4GHz [erg/s/Hz], computed from ``pred.ionizing.q_h``.
            Q_H is NaN when no nebular model is active, so this is presumably
            NaN in that case too (depends on ``compute_l_radio_thermal``).

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.
        """
        return compute_l_radio_thermal(self._pred.ionizing.q_h)

    @property
    def l_nonthermal(self):
        """Non-thermal (synchrotron) radio luminosity density at 1.4 GHz.

        Returns
        -------
        float
            ``l_1p4ghz - l_thermal`` [erg/s/Hz]. Not clipped, so it can be
            negative if the thermal estimate exceeds the total.

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.
        """
        total = self.l_1p4ghz
        thermal = self.l_thermal
        return total - thermal

    @property
    def q_ir(self):
        """FIR-radio correlation parameter.

        Returns
        -------
        float
            q_TIR [dimensionless], from ``pred.sed.l_tir`` and ``l_1p4ghz``.

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.
        """
        return compute_q_ir(self._pred.sed.l_tir, self.l_1p4ghz)
# ── X-ray properties (lazy) ───────────────────────────────────────
class XRayProperties(_CachedBase):
    """X-ray quantities derived from empirical scaling relations.

    Reached as :attr:`Prediction.xray`. X-ray binary emission follows
    Lehmer et al. (2010, 2016); the AGN contribution follows from the AGN
    bolometric luminosity via Duras et al. (2020) bolometric corrections.

    Attributes
    ----------
    l_x_xrb : property
        X-ray binary luminosity [erg/s].
    l_x_agn : property
        AGN X-ray luminosity [erg/s].
    l_x_total : property
        Total X-ray luminosity [erg/s].

    Notes
    -----
    This is a lazy accessor bound to a parent :class:`Prediction`, not an
    array container. Values are evaluated on access from intermediates the
    parent caches in Python, so it is not JIT-compatible.

    Examples
    --------
    >>> pred = model.predict(params)
    >>> pred.xray.l_x_xrb  # XRB luminosity (0.5-8 keV)
    Array(3.1e40, dtype=float64)
    """

    @property
    def l_x_xrb(self):
        """X-ray luminosity contributed by X-ray binaries.

        Returns
        -------
        float
            XRB luminosity in the 0.5–8 keV band [erg/s], from the 100 Myr SFR
            and the stellar mass.

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.
        """
        sfh_group = self._pred.sfh
        return compute_l_x_xrb(sfh_group.sfr_100myr, sfh_group.stellar_mass)

    @property
    def l_x_agn(self):
        """X-ray luminosity contributed by an AGN.

        Returns
        -------
        float
            AGN luminosity in the 2–10 keV band [erg/s]. When the orchestrator
            published no ``L_agn_bol``, a bolometric AGN luminosity of 0.0 is
            passed to the scaling relation.

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.
        """
        pred = self._pred
        pred._ensure_sed()
        return compute_l_x_agn(pred._cache.get("agn_bol_erg", 0.0))

    @property
    def l_x_total(self):
        """Sum of the XRB and AGN X-ray luminosities.

        Returns
        -------
        float
            ``l_x_xrb + l_x_agn`` [erg/s].

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.

        NOTE(review): the component docstrings place XRB in 0.5–8 keV and AGN
        in 2–10 keV, so this sum mixes bands — confirm against
        ``compute_l_x_xrb`` / ``compute_l_x_agn``.
        """
        xrb = self.l_x_xrb
        agn = self.l_x_agn
        return xrb + agn
# ── Ionizing properties (lazy) ────────────────────────────────────
class IonizingProperties(_CachedBase):
    """Ionizing-photon budget quantities.

    Reached as :attr:`Prediction.ionizing`. The ionizing photon rate Q_H
    is filled in by :meth:`Prediction._ensure_lines` from the nebular
    backend (Cloudy grid or Cue emulator); without an active nebular model
    it is NaN.

    Attributes
    ----------
    q_h : property
        Ionizing photon production rate [photons/s].
    xi_ion : property
        Ionizing photon production efficiency [Hz/erg].

    Notes
    -----
    This is a lazy accessor bound to a parent :class:`Prediction`, not an
    array container. Values are read from the parent's Python-level
    cache, so it is not JIT-compatible.

    Examples
    --------
    >>> pred = model.predict(params)
    >>> pred.ionizing.xi_ion  # ionizing efficiency
    Array(25.3, dtype=float64)
    """

    @property
    def q_h(self):
        """Total ionizing photon production rate.

        Returns
        -------
        float
            Q_H [photons/s]; NaN when no nebular model is active or the
            value was never cached.

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.
        """
        pred = self._pred
        pred._ensure_lines()
        return pred._cache.get("q_h_total", jnp.array(jnp.nan))

    @property
    def xi_ion(self):
        """Ionizing photon production efficiency, Q_H / L_UV(1600 Å).

        A key parameter in cosmic reionization studies; typical log values
        are 25.0–25.6.

        Returns
        -------
        float
            log10(ξ_ion) [Hz/erg].

        Notes
        -----
        **JIT-compatible**: no — Python property accessor. Use in postprocessing,
        not inside :func:`jax.jit`.
        """
        rate = self.q_h
        sed_group = self._pred.sed
        uv_lum = compute_uv_luminosity_1600(sed_group._sed(), sed_group._wave())
        return compute_ionizing_efficiency(rate, uv_lum)
# ── Main Prediction class ─────────────────────────────────────────
[docs]
class Prediction:
"""Lazy prediction object with on-demand computation of derived quantities.
Created via ``model.predict(params)``. Properties are computed on
first access and cached. The cache is shared across all property
groups (``sfh``, ``sed``, ``lines``, ``radio``, ``xray``,
``ionizing``), so related quantities share the expensive
intermediates.
Parameters
----------
model : SEDModel
The tengri SEDModel instance.
params : dict
Parameter values (public names).
Attributes
----------
sfh : SFHProperties
Star formation history derived quantities. Lazy accessor.
sed : SEDProperties
Spectral energy distribution derived quantities. Lazy accessor.
lines : LineProperties
Emission line luminosities and diagnostic ratios. Lazy accessor.
radio : RadioProperties
Radio-derived quantities from empirical relations. Lazy accessor.
xray : XRayProperties
X-ray derived quantities from empirical relations. Lazy accessor.
ionizing : IonizingProperties
Ionizing photon budget quantities. Lazy accessor.
Returns
-------
Prediction
Lazy prediction object with cached computed quantities.
Notes
-----
This class is NOT JIT-compatible due to Python-level caching. For
batch computations over many parameter sets (MCMC chains, mock
catalogs), use the JIT-compatible methods :meth:`SEDModel.predict_sfh_quantities`,
:meth:`SEDModel.predict_sed_quantities`, etc. instead. Those return
JAX pytrees (:class:`SFHQuantities`, :class:`SEDQuantities`,
:class:`DerivedQuantities`) suitable for :func:`jax.vmap`,
:func:`jax.jit`, and :func:`jax.grad`.
Examples
--------
**Two equivalent ways to access derived quantities:**
>>> pred = model.predict(params)
>>> pred.stellar_mass # flat shortcut
>>> pred.sfh.stellar_mass # grouped form (same value)
>>> pred.dn4000 # flat
>>> pred.sed.dn4000 # grouped
>>> pred.halpha # flat
>>> pred.lines.halpha # grouped
The grouped form (``pred.sfh``, ``pred.sed``, ``pred.lines``,
``pred.radio``, ``pred.xray``, ``pred.ionizing``) exposes every
derived quantity. The top-level shortcuts cover the most-used
quantities for quick access; for less-common ones use the grouped
form. Both share the same lazy cache, so accessing a quantity by
either route triggers computation only once.
**Accessing the full SED or photometry:**
>>> pred.sed_array # shape (n_wave,)
>>> pred.photometry # shape (n_filters,)
**For batch computation, use JIT-compatible methods instead:**
>>> sfh_batch = jax.vmap(model.predict_sfh_quantities)(params_batch)
"""
__slots__ = ("_cache", "_model", "_params", "ionizing", "lines", "radio", "sed", "sfh", "xray")
def __init__(self, model, params):
    """Bind the model and parameters and create the lazy group accessors.

    Parameters
    ----------
    model : SEDModel
        Forward model used for all on-demand computation.
    params : dict
        Parameter values (public names).
    """
    self._model, self._params = model, params
    # Shared intermediate cache; every group accessor below reads and
    # writes it through its back-reference to ``self``.
    self._cache = {}
    self.sfh = SFHProperties(self)
    self.sed = SEDProperties(self)
    self.lines = LineProperties(self)
    self.radio = RadioProperties(self)
    self.xray = XRayProperties(self)
    self.ionizing = IonizingProperties(self)
def _ensure_sfh(self):
"""Populate SFH cache (SFR history, age weights, internal params).
Reads the orchestrator's :class:`PipelineState` to keep SFH-only
consumers (``stellar_mass``, ``sfr_*``, ``mass_weighted_age``)
and SED-consuming consumers
(``luminosity_weighted_age``, ``dn4000``) on the same numerics.
``weights`` is the orchestrator's DSPS-canonical
``age_weights`` (Msun per SSP age bin); ``sfr`` is
``sfr_history`` on the SFH lookback grid (same shape as
``model.age_yr``). The orchestrator state itself is cached on
``_state`` so :meth:`_ensure_sed` can reuse it without a
second forward-pass.
"""
if "weights" in self._cache:
return
p = self._model._get_internal_params(self._params)
state = self._model.predict_via_orchestrator(self._params)
derived = state.derived
# The orchestrator's stellar adapter integrates the SFH on
# ``spec.n_grid`` (default 64) regardless of whether the model
# is stochastic. Legacy ``SEDModel`` uses ``n_grid=256`` for
# non-stochastic configs, so cache consumers index ``sfr``
# against ``model.age_yr`` of length 256. Resample the
# orchestrator's SFR history to the legacy grid so masks like
# ``model.age_yr <= 1e8`` still align.
sfh_grid = jnp.asarray(derived["sfh_grid_lbt_yr"])
sfr_history = jnp.asarray(derived["sfr_history"])
sfr_on_legacy_grid = jnp.interp(self._model.age_yr, sfh_grid, sfr_history)
self._cache.update(
{
"p": p,
"sfr": sfr_on_legacy_grid,
"weights": jnp.asarray(derived["age_weights"]),
"_state": state,
}
)
def _ensure_sed(self):
"""Populate SED cache from the orchestrator's PipelineState.
Re-uses the state computed in :meth:`_ensure_sfh` (cached on
``self._cache["_state"]``). Reconstructs the legacy cache
contract:
- ``sed_total`` ← ``state.sed_intrinsic`` (post-dust total)
- ``sed_intrinsic`` ← ``sum(lnu_age, axis=0)`` (pre-dust stellar)
- ``sed_attenuated`` ← ``state.derived["sed_dust_attenuated"]``
- ``ssp_flux_at_z`` ← ``lnu_age / (age_weights * LSUN_ERG)``
(safe-divided where ``age_weights`` is zero)
- ``agn_bol_erg`` ← ``state.derived["L_agn_bol"]`` if present.
"""
if "sed_total" in self._cache:
return
self._ensure_sfh()
state = self._cache["_state"]
derived = state.derived
self._cache["sed_total"] = state.sed_intrinsic
lnu_age = derived.get("lnu_age")
if lnu_age is not None:
from tengri.utils.physics_constants import L_SUN
lnu_age_arr = jnp.asarray(lnu_age)
self._cache["sed_intrinsic"] = jnp.sum(lnu_age_arr, axis=0)
aw = jnp.asarray(self._cache["weights"])
aw_safe = jnp.maximum(aw, 1e-30)
self._cache["ssp_flux_at_z"] = lnu_age_arr / (aw_safe[:, None] * L_SUN)
sed_attenuated = derived.get("sed_dust_attenuated")
if sed_attenuated is not None:
self._cache["sed_attenuated"] = jnp.asarray(sed_attenuated)
if "L_agn_bol" in derived:
self._cache["agn_bol_erg"] = jnp.asarray(derived["L_agn_bol"])
def _ensure_lines(self):
"""Compute and cache nebular emission line luminosities.
Reads the discrete catalogue published by
:class:`~tengri.components.nebular.component.NebularSEDComponent`
(``state.derived["line_waves"]`` / ``["line_lums"]``). Matches
legacy-path luminosities within numerical tolerance for both
Cue and CloudyGrid backends after the Phase II-3 PR 5b'
orchestrator fixes (``age_weights`` plumbing +
``neb_logZ_gas`` translation; commit b7dff1b).
"""
if "line_waves" in self._cache:
return
self._ensure_sfh()
model = self._model
backend = model._nebular_backend
if backend is None or not hasattr(backend, "predict_nebular_line_luminosities"):
self._cache["line_waves"] = jnp.array([])
self._cache["line_lums"] = jnp.array([])
self._cache["q_h_total"] = jnp.array(jnp.nan)
return
# Pull the catalogue from the orchestrator's NebularSEDComponent
# publication. BakedIn / Shock backends won't publish it; fall
# back to all-NaN for those (matches the legacy "no catalogue"
# behaviour without raising).
state = model.predict_via_orchestrator(self._params)
derived = state.derived
if "line_waves" in derived and "line_lums" in derived:
self._cache["line_waves"] = jnp.asarray(derived["line_waves"])
self._cache["line_lums"] = jnp.asarray(derived["line_lums"])
else:
self._cache["line_waves"] = jnp.array([])
self._cache["line_lums"] = jnp.array([])
# Q_H: compute from backend's precomputed table if available.
# Uses the orchestrator-published ``age_weights`` (Msun/bin) +
# ``log_metallicity_history`` (present-day value) so the value
# matches the legacy path even when other consumers of the
# cache still go through the legacy ``_ensure_sed``.
weights_orch = derived.get("age_weights")
log_z_history = derived.get("log_metallicity_history")
if (
weights_orch is not None
and log_z_history is not None
and hasattr(backend, "_qh_table")
and backend._qh_table is not None
):
log_z = jnp.asarray(log_z_history)[0]
young_idx = backend._young_idx
young_ages = model.ssp_log_ages_yr[young_idx]
young_weights = jnp.asarray(weights_orch)[young_idx]
def _qh_one_bin(log_age_i, w_i):
"""Ionising photon production rate for one age bin."""
return w_i * backend._get_qh_at(log_z, log_age_i)
import jax
q_h_per_bin = jax.vmap(_qh_one_bin)(young_ages, young_weights)
neb_fesc = jnp.asarray(self._params.get("neb_fesc", 0.0))
self._cache["q_h_total"] = jnp.sum(q_h_per_bin) * (1.0 - neb_fesc)
else:
self._cache["q_h_total"] = jnp.array(jnp.nan)
@property
def sed_array(self):
    """The full rest-frame SED, as cached under ``sed_total``.

    Returns
    -------
    ndarray, shape (n_wave,)
        Total spectral energy distribution [erg/s/Hz].

    Notes
    -----
    **JIT-compatible**: no — Python property accessor. Use in postprocessing,
    not inside :func:`jax.jit`.
    """
    self._ensure_sed()
    total_sed = self._cache["sed_total"]
    return total_sed
@property
def photometry(self):
    """Observed photometric flux densities.

    Computed on first access with ``model.predict_photometry(params)`` and
    stored in the shared cache under ``"photometry"``. Later accesses return
    the cached array instead of re-running the full forward model, matching
    the "computed on first access and cached" contract of :class:`Prediction`.

    Returns
    -------
    ndarray, shape (n_filters,)
        Photometry at the filters defined in the SEDModel [erg/s/cm²/Hz].

    Notes
    -----
    **JIT-compatible**: no — Python property accessor. Use in postprocessing,
    not inside :func:`jax.jit`.

    Examples
    --------
    .. code-block:: python

        pred = model.predict(params)
        phot = pred.photometry  # ndarray, shape (n_filters,)
        print(phot.shape)       # e.g. (8,) for 8 photometric bands
    """
    cache = self._cache
    if "photometry" not in cache:
        cache["photometry"] = self._model.predict_photometry(self._params)
    return cache["photometry"]
# ── Top-level shortcuts to grouped derived quantities ───────────────
# ``pred.stellar_mass`` and ``pred.sfh.stellar_mass`` return the same value;
# the flat form is for tab-completion convenience and aligns with how
# astronomers typically refer to derived quantities (no domain prefix).
# Where two groups expose the same name (e.g. luminosity_weighted_*),
# the flat shortcut points to the SED version — that's the canonical
# "luminosity-weighted" meaning (uses attenuated SED, not stellar-only).
# --- SFH-derived (forward to pred.sfh) ---
@property
def stellar_mass(self):
    """Alias of ``pred.sfh.stellar_mass``: total stellar mass formed [M☉]."""
    group = self.sfh
    return group.stellar_mass

@property
def stellar_mass_surviving(self):
    """Alias of ``pred.sfh.stellar_mass_surviving``: surviving stellar + remnant mass [M☉]."""
    group = self.sfh
    return group.stellar_mass_surviving

@property
def sfr_100myr(self):
    """Alias of ``pred.sfh.sfr_100myr``: SFR averaged over the last 100 Myr [M☉/yr]."""
    group = self.sfh
    return group.sfr_100myr

@property
def sfr_10myr(self):
    """Alias of ``pred.sfh.sfr_10myr``: SFR averaged over the last 10 Myr [M☉/yr]."""
    group = self.sfh
    return group.sfr_10myr

@property
def ssfr(self):
    """Alias of ``pred.sfh.ssfr``: specific SFR [yr⁻¹]."""
    group = self.sfh
    return group.ssfr

@property
def mass_weighted_age_gyr(self):
    """Alias of ``pred.sfh.mass_weighted_age_gyr``: mass-weighted stellar age [Gyr]."""
    group = self.sfh
    return group.mass_weighted_age_gyr

@property
def mass_weighted_metallicity(self):
    """Alias of ``pred.sfh.mass_weighted_metallicity``: mass-weighted log₁₀(Z/Z☉)."""
    group = self.sfh
    return group.mass_weighted_metallicity
# --- SED-derived (forward to pred.sed) ---
@property
def l_bol(self):
    """Alias of ``pred.sed.l_bol``: bolometric luminosity [L☉]."""
    group = self.sed
    return group.l_bol

@property
def l_tir(self):
    """Alias of ``pred.sed.l_tir``: total infrared (8–1000 μm) luminosity [L☉]."""
    group = self.sed
    return group.l_tir

@property
def l_dust_absorbed(self):
    """Alias of ``pred.sed.l_dust_absorbed``: dust-absorbed luminosity [L☉]."""
    group = self.sed
    return group.l_dust_absorbed

@property
def irx(self):
    """Alias of ``pred.sed.irx``: infrared excess L_TIR / L_UV(1600 Å)."""
    group = self.sed
    return group.irx

@property
def uv_slope_beta(self):
    """Alias of ``pred.sed.uv_slope_beta``: UV slope β in f_λ ∝ λ^β."""
    group = self.sed
    return group.uv_slope_beta

@property
def dn4000(self):
    """Alias of ``pred.sed.dn4000``: D_n(4000) break ratio."""
    group = self.sed
    return group.dn4000

@property
def balmer_break(self):
    """Alias of ``pred.sed.balmer_break``: Balmer break flux ratio."""
    group = self.sed
    return group.balmer_break

@property
def m_uv(self):
    """Alias of ``pred.sed.m_uv``: absolute magnitude at 1500 Å."""
    group = self.sed
    return group.m_uv

@property
def fuv_flux(self):
    """Alias of ``pred.sed.fuv_flux``: FUV flux at 1500 Å [erg/s/cm²]."""
    group = self.sed
    return group.fuv_flux

@property
def nuv_flux(self):
    """Alias of ``pred.sed.nuv_flux``: NUV flux at 2300 Å [erg/s/cm²]."""
    group = self.sed
    return group.nuv_flux

@property
def fuv_flux_intrinsic(self):
    """Alias of ``pred.sed.fuv_flux_intrinsic``: dust-free FUV flux."""
    group = self.sed
    return group.fuv_flux_intrinsic

@property
def nuv_flux_intrinsic(self):
    """Alias of ``pred.sed.nuv_flux_intrinsic``: dust-free NUV flux."""
    group = self.sed
    return group.nuv_flux_intrinsic

@property
def rest_uv_color(self):
    """Alias of ``pred.sed.rest_uv_color``: rest-frame UV color (f_1500 − f_2300)."""
    group = self.sed
    return group.rest_uv_color

@property
def luminosity_weighted_age_gyr(self):
    """Alias of ``pred.sed.luminosity_weighted_age_gyr``: luminosity-weighted age [Gyr].

    ``pred.sfh`` defines a quantity with the same name too. This shortcut
    deliberately forwards to the SED version, the canonical
    "luminosity-weighted" definition based on the attenuated stellar SED.
    """
    group = self.sed
    return group.luminosity_weighted_age_gyr

@property
def luminosity_weighted_metallicity(self):
    """Alias of ``pred.sed.luminosity_weighted_metallicity``.

    Luminosity-weighted log₁₀(Z/Z☉).
    """
    group = self.sed
    return group.luminosity_weighted_metallicity
# --- Emission lines (forward to pred.lines) ---
@property
def halpha(self):
    """Alias of ``pred.lines.halpha``: Hα 6564 Å luminosity [erg/s]."""
    group = self.lines
    return group.halpha

@property
def hbeta(self):
    """Alias of ``pred.lines.hbeta``: Hβ 4862 Å luminosity [erg/s]."""
    group = self.lines
    return group.hbeta

@property
def oiii_5007(self):
    """Alias of ``pred.lines.oiii_5007``: [O III] 5007 Å luminosity [erg/s]."""
    group = self.lines
    return group.oiii_5007

@property
def balmer_decrement(self):
    """Alias of ``pred.lines.balmer_decrement``: Hα/Hβ flux ratio."""
    group = self.lines
    return group.balmer_decrement
# --- Ionizing budget (forward to pred.ionizing) ---
@property
def q_h(self):
    """Alias of ``pred.ionizing.q_h``: total ionizing photon production rate [s⁻¹]."""
    group = self.ionizing
    return group.q_h

@property
def xi_ion(self):
    """Alias of ``pred.ionizing.xi_ion``: ionizing photon production efficiency [Hz·erg⁻¹]."""
    group = self.ionizing
    return group.xi_ion