Building Models
A Parameters object is essentially a declarative dict. Keys name physical parameters, and values are either priors (Uniform, Gaussian, Fixed, …) or string/boolean flags that choose components. Whether a parameter is free or fixed is inferred automatically from its prior type.
This notebook varies one structural axis at a time (SFH family, dust attenuation law, dust emission template) and plots the resulting SEDs side by side. It then times a few JIT-compiled forward calls to show that the one-off compile cost is amortized over later calls.
Setup
[1]:
import os
os.environ.setdefault("TENGRI_NO_BACKGROUND_COMPILE", "1")
import time
from pathlib import Path
import jax
import numpy as np
import matplotlib.pyplot as plt
from tengri import (
    Fixed,
    Observation,
    Parameters,
    Photometry,
    SEDModel,
    Uniform,
    load_ssp_data,
)
from tengri import cosmology, plot, units
plot.setup_style()
# Load SSP grid (MIST + C3K, Chabrier IMF)
_ssp_name = "ssp_mist_c3k_a_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5"
_repo_root = next(
    p for p in [Path.cwd(), *Path.cwd().parents] if (p / "data" / _ssp_name).exists()
)
ssp = load_ssp_data(str(_repo_root / "data" / _ssp_name))
# The plotting cells save into notebooks/figures; create the folder if it is missing
(_repo_root / "notebooks" / "figures").mkdir(parents=True, exist_ok=True)
# Lightweight photometry (pre-downloaded, no SVO API calls)
# Span UV → mid-IR for meaningful SED comparisons
filter_names = [
    "galex_fuv",  # Far-UV (1539 Å)
    "galex_nuv",  # Near-UV (2316 Å)
    "sdss_u",     # Optical (blue)
    "sdss_g",     # Optical (green)
    "sdss_r",     # Optical (red)
    "sdss_i",     # Optical (far red)
    "sdss_z",     # Optical / near-IR boundary
    "wise_w1",    # Mid-IR (3.4 μm)
    "wise_w2",    # Mid-IR (4.6 μm)
    "wise_w3",    # Mid-IR (12 μm)
]
photometry = Photometry.from_names(filter_names)
observation = Observation(photometry=photometry)
print(f"Loaded SSP grid: {_ssp_name}")
print(f"Photometry: {photometry.n_filters} filters spanning UV→IR")
print(f" {', '.join([b.replace('galex_', 'GALEX-').replace('sdss_', 'SDSS-').replace('wise_', 'WISE-') for b in filter_names])}")
Loaded SSP grid: ssp_mist_c3k_a_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5
Photometry: 10 filters spanning UV→IR
GALEX-fuv, GALEX-nuv, SDSS-u, SDSS-g, SDSS-r, SDSS-i, SDSS-z, WISE-w1, WISE-w2, WISE-w3
Structural choices, priors, and the free/fixed split
Every Parameters(...) instance holds:
- Structural choices: flags like mean_sfh_type, dust_law_bc, and dust_emission that select from the registry of physics models.
- Priors on free parameters: distributions (Uniform, Gaussian) bound to parameter names.
- Fixed values: parameters wrapped in Fixed(value) are pinned and never appear in spec.free_params.
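The sketch below shows how a declarative spec can be split by prior type alone. It is a toy under stated assumptions, not tengri's implementation. ToyUniform, ToyFixed, and split_free_fixed are hypothetical stand-ins, named so that they do not shadow tengri's Uniform and Fixed.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyUniform:
    """Hypothetical stand-in for a free parameter's prior."""
    low: float
    high: float


@dataclass(frozen=True)
class ToyFixed:
    """Hypothetical stand-in for a pinned value."""
    value: float


def split_free_fixed(spec):
    """Partition a declarative spec into free names, fixed names, and flags."""
    free, fixed, flags = {}, {}, {}
    for name, entry in spec.items():
        if isinstance(entry, ToyFixed):
            fixed[name] = entry.value
        elif isinstance(entry, ToyUniform):
            free[name] = entry
        else:
            # Strings and booleans are structural switches, not parameters.
            flags[name] = entry
    return sorted(free), sorted(fixed), flags


toy_spec = {
    "mean_sfh_type": "tsnorm",
    "met_logzsol": ToyUniform(-1.5, 0.3),
    "dust_tau_bc": ToyUniform(0.0, 2.0),
    "dust_slope": ToyFixed(-0.7),
    "redshift": ToyFixed(0.05),
    "apply_igm": False,
}
free, fixed, flags = split_free_fixed(toy_spec)
print(free)   # ['dust_tau_bc', 'met_logzsol']
print(fixed)  # ['dust_slope', 'redshift']
```

The spec never says "free" or "fixed" explicitly. The wrapper type alone decides which list a parameter lands in.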
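Let's build three Parameters objects of increasing complexity and inspect each one. Judging by the output below, parameters you leave out are filled in from registry defaults. As a result, even the "minimal" model carries default priors on dust_tau_bc and dust_tau_diff, so Models 1 and 2 print identical free/fixed splits. Adding dust_emission="dale2014" in Model 3 only contributes 13 extra fixed IR-emission parameters.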
[2]:
# ─── Model 1: Minimal (just SFH + metallicity) ──────────────────────
print("MODEL 1: Minimal (SFH only)")
spec_minimal = Parameters(
    mean_sfh_type="tsnorm",
    sfh_tsnorm_log_peak_sfr=Uniform(-1.0, 2.5),
    sfh_tsnorm_peak_lbt_gyr=Uniform(0.5, 12.0),
    sfh_tsnorm_width_gyr=Uniform(0.3, 5.0),
    sfh_tsnorm_skew=Uniform(-1.0, 1.0),
    sfh_tsnorm_trunc=Uniform(1.0, 10.0),
    met_logzsol=Uniform(-1.5, 0.3),
    redshift=Fixed(0.05),
    apply_igm=False,
)
print(f"Free parameters ({len(spec_minimal.free_params)}): {spec_minimal.free_params}")
print(f"Fixed parameters ({len(spec_minimal.fixed_params)}): {spec_minimal.fixed_params}")
# ─── Model 2: Add dust attenuation ────────────────────────────────────
print("MODEL 2: + Dust attenuation (two-component)")
spec_dust_atten = Parameters(
    mean_sfh_type="tsnorm",
    sfh_tsnorm_log_peak_sfr=Uniform(-1.0, 2.5),
    sfh_tsnorm_peak_lbt_gyr=Uniform(0.5, 12.0),
    sfh_tsnorm_width_gyr=Uniform(0.3, 5.0),
    sfh_tsnorm_skew=Uniform(-1.0, 1.0),
    sfh_tsnorm_trunc=Uniform(1.0, 10.0),
    met_logzsol=Uniform(-1.5, 0.3),
    dust_model="two_component",
    dust_law_bc="calzetti",
    dust_tau_bc=Uniform(0.0, 2.0),
    dust_tau_diff=Uniform(0.0, 1.5),
    dust_slope=Fixed(-0.7),
    redshift=Fixed(0.05),
    apply_igm=False,
)
print(f"Free parameters ({len(spec_dust_atten.free_params)}): {spec_dust_atten.free_params}")
print(f"Fixed parameters ({len(spec_dust_atten.fixed_params)}): {spec_dust_atten.fixed_params}")
# ─── Model 3: Full energy-balance model ──────────────────────────────
print("MODEL 3: Full energy-balance (+ dust IR emission)")
spec_full = Parameters(
    mean_sfh_type="tsnorm",
    sfh_tsnorm_log_peak_sfr=Uniform(-1.0, 2.5),
    sfh_tsnorm_peak_lbt_gyr=Uniform(0.5, 12.0),
    sfh_tsnorm_width_gyr=Uniform(0.3, 5.0),
    sfh_tsnorm_skew=Uniform(-1.0, 1.0),
    sfh_tsnorm_trunc=Uniform(1.0, 10.0),
    met_logzsol=Uniform(-1.5, 0.3),
    dust_model="two_component",
    dust_law_bc="calzetti",
    dust_tau_bc=Uniform(0.0, 2.0),
    dust_tau_diff=Uniform(0.0, 1.5),
    dust_slope=Fixed(-0.7),
    dust_emission="dale2014",
    redshift=Fixed(0.05),
    apply_igm=False,
)
print(f"Free parameters ({len(spec_full.free_params)}): {spec_full.free_params}")
print(f"Fixed parameters ({len(spec_full.fixed_params)}): {spec_full.fixed_params}")
======================================================================
MODEL 1: Minimal (SFH only)
======================================================================
Free parameters (8): ['dust_tau_bc', 'dust_tau_diff', 'met_logzsol', 'sfh_tsnorm_log_peak_sfr', 'sfh_tsnorm_peak_lbt_gyr', 'sfh_tsnorm_skew', 'sfh_tsnorm_trunc', 'sfh_tsnorm_width_gyr']
Fixed parameters (10): ['dust_Rv', 'dust_bump_strength', 'dust_delta', 'dust_f_obscuration', 'dust_slope', 'met_alpha_fe', 'noise_dof', 'noise_frac_cal', 'redshift', 'sigma_v_kms']
======================================================================
MODEL 2: + Dust attenuation (two-component)
======================================================================
Free parameters (8): ['dust_tau_bc', 'dust_tau_diff', 'met_logzsol', 'sfh_tsnorm_log_peak_sfr', 'sfh_tsnorm_peak_lbt_gyr', 'sfh_tsnorm_skew', 'sfh_tsnorm_trunc', 'sfh_tsnorm_width_gyr']
Fixed parameters (10): ['dust_Rv', 'dust_bump_strength', 'dust_delta', 'dust_f_obscuration', 'dust_slope', 'met_alpha_fe', 'noise_dof', 'noise_frac_cal', 'redshift', 'sigma_v_kms']
======================================================================
MODEL 3: Full energy-balance (+ dust IR emission)
======================================================================
Free parameters (8): ['dust_tau_bc', 'dust_tau_diff', 'met_logzsol', 'sfh_tsnorm_log_peak_sfr', 'sfh_tsnorm_peak_lbt_gyr', 'sfh_tsnorm_skew', 'sfh_tsnorm_trunc', 'sfh_tsnorm_width_gyr']
Fixed parameters (23): ['dust_Rv', 'dust_T', 'dust_T_cold', 'dust_T_warm', 'dust_alpha_dale', 'dust_alpha_dl14', 'dust_alpha_mir', 'dust_beta_ir', 'dust_bump_strength', 'dust_delta', 'dust_eta_balance', 'dust_f_obscuration', 'dust_gamma_dl', 'dust_log_ssfr', 'dust_qhac', 'dust_qpah', 'dust_slope', 'dust_umin', 'met_alpha_fe', 'noise_dof', 'noise_frac_cal', 'redshift', 'sigma_v_kms']
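Vary the SFH family
The parameter registry is structure-aware: when you swap mean_sfh_type, the free-parameter list updates automatically. This works because each SFH family carries its own parameter names and default priors. The loop below passes no SFH priors at all, so every sfh_* entry it reports comes from the registry defaults.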
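The sketch below shows what "structure-aware" means: the family flag is a key into a table of parameter names. SFH_REGISTRY and sfh_free_params are hypothetical and not tengri's internals. The parameter names are copied from the output of the next cell.

```python
# Hypothetical stand-in for a component registry: each SFH family maps to the
# parameter names it needs (names copied from the notebook output below).
SFH_REGISTRY = {
    "dexp": ("sfh_dexp_log_peak_sfr", "sfh_dexp_tau_gyr"),
    "lnorm": (
        "sfh_lnorm_log_peak_sfr",
        "sfh_lnorm_peak_lbt_gyr",
        "sfh_lnorm_width_gyr",
    ),
}


def sfh_free_params(family, pinned=frozenset()):
    """Free SFH parameters for `family`, minus any names the user pinned."""
    if family not in SFH_REGISTRY:
        raise ValueError(f"unknown SFH family {family!r}; known: {sorted(SFH_REGISTRY)}")
    return sorted(name for name in SFH_REGISTRY[family] if name not in pinned)


# Swapping the family flag swaps the whole parameter list.
print(sfh_free_params("dexp"))
print(sfh_free_params("lnorm", pinned={"sfh_lnorm_width_gyr"}))
```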
[3]:
sfh_families = [
    ("tsnorm", {
        "sfh_tsnorm_log_peak_sfr": np.log10(15.0),
        "sfh_tsnorm_peak_lbt_gyr": 3.0,
        "sfh_tsnorm_width_gyr": 2.5,
        "sfh_tsnorm_skew": 0.2,
        "sfh_tsnorm_trunc": 4.0,
    }),
    ("dpl", {
        "sfh_dpl_log_peak_sfr": np.log10(15.0),
        "sfh_dpl_alpha": 2.0,
        "sfh_dpl_beta": 1.5,
        "sfh_dpl_tau_gyr": 2.0,
    }),
    ("dexp", {
        # Delayed-exponential ("delayed-tau") SFH. The usual form is
        # SFR(t) ∝ t·exp(-t/tau), with t counted from the onset of star
        # formation: a smooth rise followed by an exponential decline.
        "sfh_dexp_log_peak_sfr": np.log10(15.0),
        "sfh_dexp_tau_gyr": 2.5,
    }),
    ("lnorm", {
        "sfh_lnorm_log_peak_sfr": np.log10(15.0),
        "sfh_lnorm_peak_lbt_gyr": 3.0,
        "sfh_lnorm_width_gyr": 0.6,
    }),
    ("dirichlet", {
        # Non-parametric Dirichlet SFH: piecewise-constant SFR in 6 age bins,
        # constrained by a Dirichlet prior. The z_* values are RAW
        # stick-breaking fractions on Uniform(0.01, 0.99), NOT log values.
        # Each z_k is the share of the still-unassigned mass placed in bin k,
        # not bin k's share of the total. Confirm the bin edges and ordering
        # (e.g. via tengri.describe) before reading off a history shape.
        "sfh_dir_log_total_mass": np.log10(1e10),
        "sfh_dir_z_0": 0.6,   # Earliest bin: 60% of remaining mass
        "sfh_dir_z_1": 0.5,   # 50% of remaining
        "sfh_dir_z_2": 0.4,   # 40% of remaining
        "sfh_dir_z_3": 0.3,   # 30% of remaining
        "sfh_dir_z_4": 0.2,   # 20% of remaining
        "sfh_dir_z_5": 0.15,  # Most recent: 15% of remaining
    }),
]
print("\nSFH Family Comparison")
for sfh_name, _ in sfh_families:
    spec_sfh = Parameters(
        mean_sfh_type=sfh_name,
        met_logzsol=Fixed(-0.1),
        dust_model="two_component",
        dust_law_bc="calzetti",
        dust_tau_bc=Fixed(0.5),
        dust_tau_diff=Fixed(0.3),
        dust_slope=Fixed(-0.7),
        dust_emission="dale2014",
        redshift=Fixed(0.05),
        apply_igm=False,
    )
    sfh_params = [p for p in spec_sfh.free_params if p.startswith("sfh_")]
    print(f"{sfh_name:20s} {len(sfh_params):2d} SFH params: {', '.join(sfh_params)}")
print("\n[TIP] Use tengri.describe(name) to inspect any SFH family or component in detail.")
print("Example: tengri.describe('delayed_bq') shows the parametrization and physics.")
SFH Family Comparison
======================================================================
tsnorm 5 SFH params: sfh_tsnorm_log_peak_sfr, sfh_tsnorm_peak_lbt_gyr, sfh_tsnorm_skew, sfh_tsnorm_trunc, sfh_tsnorm_width_gyr
dpl 4 SFH params: sfh_dpl_alpha, sfh_dpl_beta, sfh_dpl_log_peak_sfr, sfh_dpl_tau_gyr
dexp 2 SFH params: sfh_dexp_log_peak_sfr, sfh_dexp_tau_gyr
lnorm 3 SFH params: sfh_lnorm_log_peak_sfr, sfh_lnorm_peak_lbt_gyr, sfh_lnorm_width_gyr
dirichlet 7 SFH params: sfh_dir_log_total_mass, sfh_dir_z_0, sfh_dir_z_1, sfh_dir_z_2, sfh_dir_z_3, sfh_dir_z_4, sfh_dir_z_5
[TIP] Use tengri.describe(name) to inspect any SFH family or component in detail.
Example: tengri.describe('delayed_bq') shows the parametrization and physics.
SED under different SFH families
Build a truth dict for each SFH family and compute the resulting SEDs. Notice how the spectral shape changes with the SFH family, especially the signature of recent star formation.
Not used here: model.predict_sfh_quantities(truth) is intended to extract the mass-weighted age and other SFH diagnostics. These can be overlaid on SFR(t) to check the age structure each family implies.
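The mass-weighted age mentioned above can also be computed directly from a sampled SFR curve. The NumPy sketch below assumes that t is lookback time in Gyr, as in the t_gyr array that model.predict_sfh returns in the next cell. It weights by mass formed and ignores stellar mass loss. mass_weighted_age and trapezoid are hypothetical helpers, and the delayed-tau curve uses the textbook form SFR ∝ t·exp(-t/τ).

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoidal integral, written out to avoid NumPy-version differences."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def mass_weighted_age(t_lookback_gyr, sfr):
    """Formed-mass-weighted lookback age in Gyr (stellar mass loss ignored)."""
    t = np.asarray(t_lookback_gyr, dtype=float)
    s = np.asarray(sfr, dtype=float)
    return trapezoid(t * s, t) / trapezoid(s, t)


# Lookback-time grid from today (0) back to 10 Gyr.
t_lb = np.linspace(0.0, 10.0, 2001)

# Constant SFR: the mass-weighted age is half the span.
print(mass_weighted_age(t_lb, np.ones_like(t_lb)))  # ≈ 5.0

# Delayed-tau SFR(t) ∝ t exp(-t/tau), with onset 10 Gyr ago and tau = 2.5 Gyr.
tau = 2.5
t_since_onset = 10.0 - t_lb
sfr_dexp = t_since_onset * np.exp(-t_since_onset / tau)
print(mass_weighted_age(t_lb, sfr_dexp))  # ≈ 5.81
```

Inside the plotting loop, the same helper could be applied as mass_weighted_age(sfr_curve["t_gyr"], sfr_curve["sfr_mean"]).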
[4]:
n_sfh = len(sfh_families)
fig = plt.figure(figsize=(14, 2.5 * n_sfh))
gs = fig.add_gridspec(n_sfh, 2, hspace=0.35, wspace=0.25)
z = 0.05
dl_cm = float(cosmology.luminosity_distance(z))
for row, (sfh_name, truth_sfh) in enumerate(sfh_families):
    # Build model for this SFH family
    spec = Parameters(
        mean_sfh_type=sfh_name,
        met_logzsol=Fixed(-0.1),
        dust_model="two_component",
        dust_law_bc="calzetti",
        dust_tau_bc=Fixed(0.5),
        dust_tau_diff=Fixed(0.3),
        dust_slope=Fixed(-0.7),
        dust_emission="dale2014",
        redshift=Fixed(z),
        apply_igm=False,
    )
    model = SEDModel(spec, ssp, observation=observation)
    # Build truth dict
    truth = {
        "met_logzsol": -0.1,
        "dust_tau_bc": 0.5,
        "dust_tau_diff": 0.3,
        "dust_slope": -0.7,
        "redshift": z,
    }
    truth.update(truth_sfh)
    # LEFT: SFR(t) curve
    ax_sfr = fig.add_subplot(gs[row, 0])
    sfr_curve = model.predict_sfh(truth)
    t_lookback = np.asarray(sfr_curve["t_gyr"])
    sfr_values = np.asarray(sfr_curve["sfr_mean"])
    color = plot.COLORS["seq"][row % len(plot.COLORS["seq"])]
    ax_sfr.loglog(t_lookback, np.maximum(sfr_values, 1e-3), lw=2,
                  color=color, label=sfh_name)
    ax_sfr.set_xlabel("Lookback time [Gyr]")
    ax_sfr.set_ylabel("SFR [M$_\\odot$ yr$^{-1}$]")
    ax_sfr.grid(True, alpha=0.2, which="both")
    ax_sfr.legend(loc="upper left", frameon=False)
    # RIGHT: rest-frame model SED, shifted to the observed frame at redshift z
    ax_sed = fig.add_subplot(gs[row, 1])
    sed = model.predict_rest_sed(truth)
    wave_obs_um = np.asarray(sed.wavelength) * (1.0 + z) / 1e4
    sed_fnu = np.asarray(units.lnu_to_fnu(sed.sed, dl_cm, z))
    # Clip to the plotted window so log-autoscale ignores near-zero pixels
    mask = (wave_obs_um >= 0.1) & (wave_obs_um <= 30)
    ax_sed.loglog(wave_obs_um[mask], sed_fnu[mask], lw=2, color=color)
    ax_sed.set_xlabel(r"Observed wavelength [$\mu$m]")
    ax_sed.set_ylabel(r"$f_\nu$ [erg s$^{-1}$ cm$^{-2}$ Hz$^{-1}$]")
    ax_sed.set_xlim(0.1, 30)
    ymed = np.median(sed_fnu[mask & (sed_fnu > 0)])
    ax_sed.set_ylim(ymed / 1e3, ymed * 30)
    ax_sed.grid(True, alpha=0.2, which="both")
fig.suptitle(f"SFH families: SFR(t) and SED observed at z = {z}", fontsize=12, y=0.995)
plt.savefig(str(_repo_root / "notebooks" / "figures" / "04_sfh_family_grid.png"), dpi=200, bbox_inches="tight")
plt.show()
/Users/suchethacooray/Projects/tengri/src/tengri/forward/sed_model.py:636: BakedInNebularWarning: BakedInBackend: nebular emission is baked into the SSP file at a FIXED logU and FIXED escape fraction determined when the SSP grid was generated (commonly logU = −3, but depends on the SSP file). The ionization parameter and escape fraction are NOT free parameters — varying neb_logU or neb_fesc in your Parameters will have no effect. Check your SSP file's nebular assumptions. Switch to CloudyGridBackend or CueBackend to vary nebular properties. To suppress: pass ionizing_source_warning='suppress'.
self._nebular_backend = BakedInBackend()
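Vary the dust attenuation law
Keep the SFH fixed (tsnorm) and sweep the birth-cloud attenuation law, dust_law_bc. Both the amount of attenuation and the shape of the attenuation curve set the UV-to-optical ratio and the overall tilt of the SED. Only dust_law_bc changes here. The summary printed in the free-vs-fixed section reports dust_law=calzetti/calzetti, which suggests that the diffuse component keeps its own law (Calzetti by default).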
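The parameter names of the two_component model (dust_tau_bc, dust_tau_diff, dust_slope) match the Charlot & Fall (2000) two-component picture. The sketch below implements that scheme with a pure power-law curve normalized at 5500 Å. It assumes the standard CF00 form. tengri presumably substitutes the law named in dust_law_bc for the birth-cloud curve, so treat the shape here as illustrative. two_component_transmission is a hypothetical helper.

```python
import numpy as np


def two_component_transmission(wave_aa, tau_bc, tau_diff, slope=-0.7):
    """Charlot & Fall (2000)-style transmission for young and old stars.

    Both components use tau(lambda) = tau * (lambda / 5500 A) ** slope.
    Young stars (conventionally < 10 Myr) sit inside birth clouds and see
    tau_bc + tau_diff; older stars see only the diffuse tau_diff.
    """
    shape = (np.asarray(wave_aa, dtype=float) / 5500.0) ** slope
    trans_young = np.exp(-(tau_bc + tau_diff) * shape)
    trans_old = np.exp(-tau_diff * shape)
    return trans_young, trans_old


wave_aa = np.array([1500.0, 5500.0, 22000.0])  # FUV, V band, K band
young, old = two_component_transmission(wave_aa, tau_bc=0.5, tau_diff=0.3)
print(young.round(3))
print(old.round(3))
```

With a negative slope the optical depth grows toward the UV. This is why both the attenuation amount and the curve shape move the UV-to-optical ratio.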
[5]:
dust_laws = [
    "calzetti",
    "salim",
    "smc",
    "kriek_conroy",
    "cardelli",
    "noll09",
]
print("\nDust Law Comparison")
for dust_law in dust_laws:
    spec = Parameters(
        mean_sfh_type="tsnorm",
        sfh_tsnorm_log_peak_sfr=Fixed(np.log10(15.0)),
        sfh_tsnorm_peak_lbt_gyr=Fixed(3.0),
        sfh_tsnorm_width_gyr=Fixed(2.5),
        sfh_tsnorm_skew=Fixed(0.2),
        sfh_tsnorm_trunc=Fixed(4.0),
        met_logzsol=Fixed(-0.1),
        dust_model="two_component",
        dust_law_bc=dust_law,
        dust_tau_bc=Fixed(0.5),
        dust_tau_diff=Fixed(0.3),
        dust_slope=Fixed(-0.7),
        dust_emission="dale2014",
        redshift=Fixed(0.05),
        apply_igm=False,
    )
    free_dust = [p for p in spec.free_params if p.startswith("dust_")]
    print(f"{dust_law:20s} dust free params: {free_dust}")
Dust Law Comparison
======================================================================
calzetti dust free params: []
salim dust free params: []
smc dust free params: []
kriek_conroy dust free params: []
cardelli dust free params: []
noll09 dust free params: []
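SED under different attenuation laws
The free-parameter lists printed above are empty because every dust parameter in that loop is either pinned with Fixed(...) or left at a fixed registry default. The law choice still matters. It sets the UV-optical tilt and controls features such as the 2175 Å bump. The bump is present in the Milky Way curve (Cardelli) and can be carried by the Noll et al. (2009) and Kriek & Conroy parametrizations, but it is absent from the Calzetti starburst and SMC curves. The left panel shows the intrinsic (unattenuated) SED as a reference. The right panel overlays the attenuated SED for each law at fixed optical depths (τ_bc = 0.5, τ_diff = 0.3).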
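Bump-bearing parametrizations such as Noll et al. (2009) model the 2175 Å feature as a Drude profile added to the attenuation curve. The sketch below uses that form with the commonly quoted centre (2175 Å) and width (350 Å). The mapping of tengri's dust_bump_strength onto the amplitude E_b is an assumption. If the mapping holds, the default dust_bump_strength = 0 shown in the summary later would switch the bump off. drude_bump is a hypothetical helper.

```python
import numpy as np


def drude_bump(wave_aa, amplitude, center_aa=2175.0, width_aa=350.0):
    """Drude profile for the UV bump (Noll et al. 2009 form).

    D(lambda) = E_b (lambda * dlambda)^2
                / ((lambda^2 - lambda0^2)^2 + (lambda * dlambda)^2)
    """
    w = np.asarray(wave_aa, dtype=float)
    core = (w * width_aa) ** 2
    return amplitude * core / ((w**2 - center_aa**2) ** 2 + core)


wave_aa = np.array([1500.0, 2175.0, 5500.0])
print(drude_bump(wave_aa, amplitude=3.0))  # maximal at 2175 A, where D = E_b
print(drude_bump(wave_aa, amplitude=0.0))  # zero amplitude: no bump at all
```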
[6]:
fig = plt.figure(figsize=(14, 3.5))
gs = fig.add_gridspec(1, 2, wspace=0.3)
# Fixed SFH for all dust laws
truth_sfh = {
    "sfh_tsnorm_log_peak_sfr": np.log10(15.0),
    "sfh_tsnorm_peak_lbt_gyr": 3.0,
    "sfh_tsnorm_width_gyr": 2.5,
    "sfh_tsnorm_skew": 0.2,
    "sfh_tsnorm_trunc": 4.0,
}
# LEFT: Intrinsic SED (zero dust) as reference
ax_ref = fig.add_subplot(gs[0])
spec_nodust = Parameters(
    mean_sfh_type="tsnorm",
    **{k: Fixed(v) for k, v in truth_sfh.items()},
    met_logzsol=Fixed(-0.1),
    dust_model="two_component",
    dust_law_bc="calzetti",
    dust_tau_bc=Fixed(0.0),  # Zero attenuation
    dust_tau_diff=Fixed(0.0),
    dust_slope=Fixed(-0.7),
    dust_emission="dale2014",
    redshift=Fixed(z),
    apply_igm=False,
)
model_nodust = SEDModel(spec_nodust, ssp, observation=observation)
truth_nodust = {
    **truth_sfh,
    "met_logzsol": -0.1,
    "dust_tau_bc": 0.0,
    "dust_tau_diff": 0.0,
    "dust_slope": -0.7,
    "redshift": z,
}
sed_nodust = model_nodust.predict_rest_sed(truth_nodust)
wave_obs_um = np.asarray(sed_nodust.wavelength) * (1.0 + z) / 1e4
sed_fnu_nodust = np.asarray(units.lnu_to_fnu(sed_nodust.sed, dl_cm, z))
# Clip to the plotted window so log-autoscale stays honest
_mask_ref = (wave_obs_um >= 0.1) & (wave_obs_um <= 30)
ax_ref.loglog(wave_obs_um[_mask_ref], sed_fnu_nodust[_mask_ref], lw=2.5,
              color="black", label="Intrinsic (τ=0)")
ax_ref.set_xlabel(r"Observed wavelength [$\mu$m]")
ax_ref.set_ylabel(r"$f_\nu$ [erg s$^{-1}$ cm$^{-2}$ Hz$^{-1}$]")
ax_ref.set_xlim(0.1, 30)
_ymed_ref = np.median(sed_fnu_nodust[_mask_ref & (sed_fnu_nodust > 0)])
ax_ref.set_ylim(_ymed_ref / 1e3, _ymed_ref * 30)
ax_ref.grid(True, alpha=0.2, which="both")
ax_ref.legend(loc="upper right", frameon=False)
ax_ref.set_title("Intrinsic spectrum (no attenuation)")
# RIGHT: SEDs with each birth-cloud dust law
ax_sed = fig.add_subplot(gs[1])
for idx, dust_law in enumerate(dust_laws):
    spec = Parameters(
        mean_sfh_type="tsnorm",
        **{k: Fixed(v) for k, v in truth_sfh.items()},
        met_logzsol=Fixed(-0.1),
        dust_model="two_component",
        dust_law_bc=dust_law,
        dust_tau_bc=Fixed(0.5),
        dust_tau_diff=Fixed(0.3),
        dust_slope=Fixed(-0.7),
        dust_emission="dale2014",
        redshift=Fixed(z),
        apply_igm=False,
    )
    model = SEDModel(spec, ssp, observation=observation)
    truth = {
        **truth_sfh,
        "met_logzsol": -0.1,
        "dust_tau_bc": 0.5,
        "dust_tau_diff": 0.3,
        "dust_slope": -0.7,
        "redshift": z,
    }
    sed = model.predict_rest_sed(truth)
    wave_obs_um = np.asarray(sed.wavelength) * (1.0 + z) / 1e4
    sed_fnu = np.asarray(units.lnu_to_fnu(sed.sed, dl_cm, z))
    color = plot.COLORS["seq"][idx % len(plot.COLORS["seq"])]
    _m = (wave_obs_um >= 0.1) & (wave_obs_um <= 30)
    ax_sed.loglog(wave_obs_um[_m], sed_fnu[_m], lw=2, label=dust_law, color=color)
    if idx == 0:
        _ymed_sed = np.median(sed_fnu[_m & (sed_fnu > 0)])
ax_sed.set_xlabel(r"Observed wavelength [$\mu$m]")
ax_sed.set_ylabel(r"$f_\nu$ [erg s$^{-1}$ cm$^{-2}$ Hz$^{-1}$]")
ax_sed.set_xlim(0.1, 30)
ax_sed.set_ylim(_ymed_sed / 1e3, _ymed_sed * 30)
ax_sed.grid(True, alpha=0.2, which="both")
ax_sed.legend(loc="upper right", frameon=False, fontsize=9)
ax_sed.set_title("Attenuated spectra (τ_bc = 0.5, τ_diff = 0.3)")
fig.suptitle("Dust attenuation law comparison", fontsize=12, y=1.00)
plt.savefig(str(_repo_root / "notebooks" / "figures" / "04_dust_law_grid.png"), dpi=200, bbox_inches="tight")
plt.show()
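Vary the dust emission model
The next cells compare four IR templates built on different assumptions:
- semi-empirical templates (Dale et al. 2014);
- a physical dust-grain model (Draine & Li 2007);
- a modified blackbody joined to a mid-IR power law (Casey 2012);
- a plain modified blackbody.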
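The simplest of these templates is an optically thin modified blackbody, F_ν ∝ ν^β B_ν(T). The defaults below (T = 35 K, β = 1.6) are taken from the fixed values of dust_T and dust_beta_ir that appear in the summary later in the notebook. Whether tengri's modified_blackbody uses exactly this form is an assumption. planck_nu and modified_blackbody are hypothetical helpers with arbitrary normalization.

```python
import numpy as np

H = 6.62607015e-34   # Planck constant [J s]
K_B = 1.380649e-23   # Boltzmann constant [J/K]
C = 2.99792458e8     # speed of light [m/s]


def planck_nu(nu_hz, temp_k):
    """Planck function B_nu(T) in SI units."""
    x = H * nu_hz / (K_B * temp_k)
    return 2.0 * H * nu_hz**3 / C**2 / np.expm1(x)


def modified_blackbody(wave_um, temp_k=35.0, beta=1.6):
    """Optically thin greybody nu^beta * B_nu(T), arbitrary normalization."""
    nu = C / (np.asarray(wave_um, dtype=float) * 1e-6)
    return nu**beta * planck_nu(nu, temp_k)


wave_um = np.logspace(0, 3, 4000)  # 1 um to 1 mm
snu = modified_blackbody(wave_um)
print(f"F_nu peaks near {wave_um[np.argmax(snu)]:.0f} um")  # ≈ 90 um
```

Raising β or T moves the peak blueward. This is why the four templates can place the far-IR bump at different wavelengths even when they share the same total luminosity.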
[7]:
dust_emissions = [
    "dale2014",
    "draine_li2007",
    "casey2012",
    "modified_blackbody",
]
print("\nDust Emission Model Comparison")
for emission in dust_emissions:
    try:
        spec = Parameters(
            mean_sfh_type="tsnorm",
            sfh_tsnorm_log_peak_sfr=Fixed(np.log10(15.0)),
            sfh_tsnorm_peak_lbt_gyr=Fixed(3.0),
            sfh_tsnorm_width_gyr=Fixed(2.5),
            sfh_tsnorm_skew=Fixed(0.2),
            sfh_tsnorm_trunc=Fixed(4.0),
            met_logzsol=Fixed(-0.1),
            dust_model="two_component",
            dust_law_bc="calzetti",
            dust_tau_bc=Fixed(0.5),
            dust_tau_diff=Fixed(0.3),
            dust_slope=Fixed(-0.7),
            dust_emission=emission,
            redshift=Fixed(0.05),
            apply_igm=False,
        )
        emission_params = [p for p in spec.free_params if p.startswith("dust_")]
        print(f"{emission:20s} dust free params: {emission_params}")
    except Exception as e:
        print(f"{emission:20s} SKIPPED ({str(e)[:40]}...)")
Dust Emission Model Comparison
======================================================================
dale2014 dust free params: []
draine_li2007 dust free params: []
casey2012 dust free params: []
modified_blackbody dust free params: []
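SED under different IR templates
The four templates produce different mid- to far-IR shapes. The stars and the attenuation are identical in every run, so energy balance requires each template to re-emit the same total luminosity, L_IR = L_dust,absorbed. The right panel therefore plots each template's L_IR normalized to the dale2014 value. Bars at 1 mean the templates agree on how much energy the dust absorbed. This is a consistency check; it does not compute L_IR / L_dust,absorbed directly.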
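Energy balance works like this: integrate what the dust removed from the stellar spectrum, then rescale the IR template so it re-emits exactly that amount. The NumPy sketch below uses a toy flat stellar spectrum, a single power-law attenuation curve, and a toy log-normal IR bump. None of these are tengri's models. trapezoid is a hypothetical helper, and both spectra are L_λ per Å so that the two integrals share units.

```python
import numpy as np


def trapezoid(y, x):
    """Trapezoidal integral, written out to avoid NumPy-version differences."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


# Toy intrinsic stellar spectrum: flat L_lambda [arbitrary units per A]
wave_aa = np.linspace(912.0, 30000.0, 5000)
l_int = np.ones_like(wave_aa)
# Attenuate with one power-law curve (tau_V = 0.8, slope -0.7)
l_att = l_int * np.exp(-0.8 * (wave_aa / 5500.0) ** -0.7)
l_absorbed = trapezoid(l_int - l_att, wave_aa)

# Toy IR template shape: log-normal bump near 100 um, arbitrary normalization
ir_wave_aa = np.logspace(4.5, 7.0, 3000)  # ~3 um to 1 mm, in A
template = np.exp(-0.5 * ((np.log10(ir_wave_aa) - 6.0) / 0.3) ** 2)

# Energy balance: rescale the template so it re-emits what was absorbed
l_ir = template * (l_absorbed / trapezoid(template, ir_wave_aa))
print(f"L_IR / L_absorbed = {trapezoid(l_ir, ir_wave_aa) / l_absorbed:.6f}")
```

Only the template's shape survives this normalization. Every template fed the same absorbed energy should therefore report the same L_IR, which is what the bar chart tests.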
[8]:
fig = plt.figure(figsize=(14, 3.5))
gs = fig.add_gridspec(1, 2, wspace=0.3)
# Fixed SFH/metallicity/dust for all emission models
truth_base = {
    "sfh_tsnorm_log_peak_sfr": np.log10(15.0),
    "sfh_tsnorm_peak_lbt_gyr": 3.0,
    "sfh_tsnorm_width_gyr": 2.5,
    "sfh_tsnorm_skew": 0.2,
    "sfh_tsnorm_trunc": 4.0,
    "met_logzsol": -0.1,
    "dust_tau_bc": 0.5,
    "dust_tau_diff": 0.3,
    "dust_slope": -0.7,
    "redshift": z,
}
# LEFT: SED per template; the same loop collects L_IR for the right panel
ax_sed = fig.add_subplot(gs[0])
l_ir_values = []
for idx, emission in enumerate(dust_emissions):
    spec = Parameters(
        mean_sfh_type="tsnorm",
        sfh_tsnorm_log_peak_sfr=Fixed(np.log10(15.0)),
        sfh_tsnorm_peak_lbt_gyr=Fixed(3.0),
        sfh_tsnorm_width_gyr=Fixed(2.5),
        sfh_tsnorm_skew=Fixed(0.2),
        sfh_tsnorm_trunc=Fixed(4.0),
        met_logzsol=Fixed(-0.1),
        dust_model="two_component",
        dust_law_bc="calzetti",
        dust_tau_bc=Fixed(0.5),
        dust_tau_diff=Fixed(0.3),
        dust_slope=Fixed(-0.7),
        dust_emission=emission,
        redshift=Fixed(z),
        apply_igm=False,
    )
    model = SEDModel(spec, ssp, observation=observation)
    sed = model.predict_rest_sed(truth_base)
    wave_obs_um = np.asarray(sed.wavelength) * (1.0 + z) / 1e4
    sed_fnu = np.asarray(units.lnu_to_fnu(sed.sed, dl_cm, z))
    color = plot.COLORS["seq"][idx % len(plot.COLORS["seq"])]
    _m_em = (wave_obs_um >= 0.1) & (wave_obs_um <= 1000)  # widen for IR bump
    ax_sed.loglog(wave_obs_um[_m_em], sed_fnu[_m_em], lw=2, label=emission, color=color)
    if idx == 0:
        _ymed_em = np.median(sed_fnu[_m_em & (sed_fnu > 0)])
    # Derived IR luminosity. A missing key becomes NaN (a visible gap),
    # not a 1.0 placeholder that would masquerade as a passing check.
    derived = model.predict_derived(truth_base)
    l_ir_values.append(derived.get("L_ir_rest", np.nan))
ax_sed.set_xlabel(r"Observed wavelength [$\mu$m]")
ax_sed.set_ylabel(r"$f_\nu$ [erg s$^{-1}$ cm$^{-2}$ Hz$^{-1}$]")
ax_sed.set_xlim(0.1, 1000)
ax_sed.set_ylim(_ymed_em / 1e4, _ymed_em * 30)
ax_sed.grid(True, alpha=0.2, which="both")
ax_sed.legend(loc="upper right", frameon=False)
ax_sed.set_title(f"SED observed at z = {z}")
# RIGHT: L_IR of each template relative to dale2014 (the first entry)
ax_balance = fig.add_subplot(gs[1])
l_ir_norm = np.array(l_ir_values) / l_ir_values[0]
colors = [plot.COLORS["seq"][i % len(plot.COLORS["seq"])] for i in range(len(dust_emissions))]
bars = ax_balance.bar(range(len(dust_emissions)), l_ir_norm,
                      color=colors, alpha=0.7, edgecolor="black", linewidth=1.2)
ax_balance.axhline(y=1.0, color="red", linestyle="--", linewidth=1.5, label="Same L_IR as dale2014")
ax_balance.set_ylabel(r"$L_{IR}$ / $L_{IR}$(dale2014)")
ax_balance.set_xticks(range(len(dust_emissions)))
ax_balance.set_xticklabels(dust_emissions, rotation=45, ha="right")
ax_balance.set_ylim(0.8, 1.2)
ax_balance.legend(loc="upper right", frameon=False, fontsize=9)
ax_balance.set_title("Energy-balance consistency check")
ax_balance.grid(True, alpha=0.2, axis="y")
fig.suptitle("Dust IR emission model comparison", fontsize=12, y=1.00)
plt.savefig(str(_repo_root / "notebooks" / "figures" / "04_dust_emission_grid.png"), dpi=200, bbox_inches="tight")
plt.show()
Free vs fixed parameters
Same physical model, different parameter freedom. The two specs below differ only in whether redshift is free (Uniform) or pinned (Fixed). The same pattern applies to any other parameter, metallicity included.
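From a sampler's point of view, pinning a parameter removes one dimension. The sampler proposes a vector over the free names only, and the fixed values are merged back before each forward call. The sketch below is a hypothetical illustration of that bookkeeping and does not describe how tengri's fitters are written. unpack is an assumed helper name.

```python
import numpy as np


def unpack(theta, free_names, fixed_values):
    """Map a sampler's free-parameter vector back to a full parameter dict."""
    theta = np.asarray(theta, dtype=float)
    if theta.shape != (len(free_names),):
        raise ValueError(f"expected {len(free_names)} values, got shape {theta.shape}")
    params = dict(fixed_values)
    params.update({name: float(v) for name, v in zip(free_names, theta)})
    return params


# Free redshift: the sampler explores 3 dimensions.
free_z = unpack([0.5, -0.1, 0.07],
                ["dust_tau_bc", "met_logzsol", "redshift"],
                {"dust_slope": -0.7})
# Fixed redshift: one dimension fewer; the pinned value rides along unchanged.
fixed_z = unpack([0.5, -0.1],
                 ["dust_tau_bc", "met_logzsol"],
                 {"dust_slope": -0.7, "redshift": 0.05})
print(free_z)
print(fixed_z)
```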
[9]:
print("\nFree vs Fixed Parameter Tracking")
# Build a reference model to show summary_str()
spec_ref = Parameters(
    mean_sfh_type="tsnorm",
    sfh_tsnorm_log_peak_sfr=Uniform(-1.0, 2.5),
    sfh_tsnorm_peak_lbt_gyr=Uniform(0.5, 12.0),
    sfh_tsnorm_width_gyr=Uniform(0.3, 5.0),
    sfh_tsnorm_skew=Uniform(-1.0, 1.0),
    sfh_tsnorm_trunc=Uniform(1.0, 10.0),
    met_logzsol=Uniform(-1.5, 0.3),
    dust_model="two_component",
    dust_law_bc="calzetti",
    dust_tau_bc=Uniform(0.0, 2.0),
    dust_tau_diff=Uniform(0.0, 1.5),
    dust_slope=Fixed(-0.7),
    dust_emission="dale2014",
    redshift=Uniform(0.01, 0.1),
    apply_igm=False,
)
print("\nModel Summary (using Parameters.summary_str()):")
print(spec_ref.summary_str())
# Model 1: free redshift
spec_free_z = Parameters(
    mean_sfh_type="tsnorm",
    sfh_tsnorm_log_peak_sfr=Uniform(-1.0, 2.5),
    sfh_tsnorm_peak_lbt_gyr=Uniform(0.5, 12.0),
    sfh_tsnorm_width_gyr=Uniform(0.3, 5.0),
    sfh_tsnorm_skew=Uniform(-1.0, 1.0),
    sfh_tsnorm_trunc=Uniform(1.0, 10.0),
    met_logzsol=Uniform(-1.5, 0.3),
    dust_model="two_component",
    dust_law_bc="calzetti",
    dust_tau_bc=Uniform(0.0, 2.0),
    dust_tau_diff=Uniform(0.0, 1.5),
    dust_slope=Fixed(-0.7),
    dust_emission="dale2014",
    redshift=Uniform(0.01, 0.1),  # FREE
    apply_igm=False,
)
# Model 2: fixed redshift
spec_fixed_z = Parameters(
    mean_sfh_type="tsnorm",
    sfh_tsnorm_log_peak_sfr=Uniform(-1.0, 2.5),
    sfh_tsnorm_peak_lbt_gyr=Uniform(0.5, 12.0),
    sfh_tsnorm_width_gyr=Uniform(0.3, 5.0),
    sfh_tsnorm_skew=Uniform(-1.0, 1.0),
    sfh_tsnorm_trunc=Uniform(1.0, 10.0),
    met_logzsol=Uniform(-1.5, 0.3),
    dust_model="two_component",
    dust_law_bc="calzetti",
    dust_tau_bc=Uniform(0.0, 2.0),
    dust_tau_diff=Uniform(0.0, 1.5),
    dust_slope=Fixed(-0.7),
    dust_emission="dale2014",
    redshift=Fixed(0.05),  # FIXED
    apply_igm=False,
)
print(f"Free redshift : {len(spec_free_z.free_params):2d} free params")
print(f" {spec_free_z.free_params}")
print()
print(f"Fixed redshift : {len(spec_fixed_z.free_params):2d} free params")
print(f" {spec_fixed_z.free_params}")
print()
print("Difference in free_params due to redshift:")
print(f" Free z has 'redshift': {'redshift' in spec_free_z.free_params}")
print(f" Fixed z has 'redshift': {'redshift' in spec_fixed_z.free_params}")
Free vs Fixed Parameter Tracking
======================================================================
Model Summary (using Parameters.summary_str()):
Parameters SFH: tsnorm
──────────────────────────────────────────────────────────────────
Dimensions: 9 free, + 22 fixed
Modules: dust_emission=dale2014, dust_law=calzetti/calzetti
Parameter Prior Bounds
────────────────────────────────────────────────────────────────
dust_tau_bc Uniform(0.0, 2.0) [0, 2]
dust_tau_diff Uniform(0.0, 1.5) [0, 1.5]
met_logzsol Uniform(-1.5, 0.3) [-1.5, 0.3]
redshift Uniform(0.01, 0.1) [0.01, 0.1]
sfh_tsnorm_log_peak_sfr Uniform(-1.0, 2.5) [-1, 2.5]
sfh_tsnorm_peak_lbt_gyr Uniform(0.5, 12.0) [0.5, 12]
sfh_tsnorm_skew Uniform(-1.0, 1.0) [-1, 1]
sfh_tsnorm_trunc Uniform(1.0, 10.0) [1, 10]
sfh_tsnorm_width_gyr Uniform(0.3, 5.0) [0.3, 5]
────────────────────────────────────────────────────────────────
dust_Rv Fixed 3.1
dust_T Fixed 35
dust_T_cold Fixed 20
dust_T_warm Fixed 45
dust_alpha_dale Fixed 2
dust_alpha_dl14 Fixed 2
dust_alpha_mir Fixed 2
dust_beta_ir Fixed 1.6
dust_bump_strength Fixed 0
dust_delta Fixed 0
dust_eta_balance Fixed 1
dust_f_obscuration Fixed 0
dust_gamma_dl Fixed 0.01
dust_log_ssfr Fixed -10
dust_qhac Fixed 0.17
dust_qpah Fixed 2.5
dust_slope Fixed -0.7
dust_umin Fixed 1
met_alpha_fe Fixed 0
noise_dof Fixed 0
noise_frac_cal Fixed 0
sigma_v_kms Fixed 0
──────────────────────────────────────────────────────────────────
Free redshift : 9 free params
['dust_tau_bc', 'dust_tau_diff', 'met_logzsol', 'redshift', 'sfh_tsnorm_log_peak_sfr', 'sfh_tsnorm_peak_lbt_gyr', 'sfh_tsnorm_skew', 'sfh_tsnorm_trunc', 'sfh_tsnorm_width_gyr']
Fixed redshift : 8 free params
['dust_tau_bc', 'dust_tau_diff', 'met_logzsol', 'sfh_tsnorm_log_peak_sfr', 'sfh_tsnorm_peak_lbt_gyr', 'sfh_tsnorm_skew', 'sfh_tsnorm_trunc', 'sfh_tsnorm_width_gyr']
Difference in free_params due to redshift:
Free z has 'redshift': True
Fixed z has 'redshift': False
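Forward-model timing and sensitivity
JAX traces and compiles the forward model on its first call. Later calls with the same input structure reuse the compiled code, and vmap can vectorize predictions over many parameter sets. Below, the first call is timed on its own, so it includes the one-off compile. It is then compared with 50 subsequent calls that vary dust_tau_bc slightly, and the difference estimates the compile cost.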
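The same pattern can be reproduced on a toy jitted function. Two details matter when timing JAX code. First, the first call pays for tracing and compilation. Second, JAX dispatches work asynchronously, so jax.block_until_ready is used to wait for the result before the clock stops. toy_photometry is a hypothetical stand-in for a forward model, not tengri's predict_photometry.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_photometry(tau, wave_aa):
    """Toy forward model: power-law attenuation of a flat spectrum."""
    return jnp.exp(-tau * (wave_aa / 5500.0) ** -0.7)


wave_aa = jnp.linspace(1500.0, 50000.0, 10)

# First call: trace + compile + run.
t0 = time.perf_counter()
jax.block_until_ready(toy_photometry(0.5, wave_aa))
t_first = time.perf_counter() - t0

# Later calls with the same input structure reuse the compiled code.
n_iter = 50
t0 = time.perf_counter()
for i in range(n_iter):
    jax.block_until_ready(toy_photometry(0.5 + 0.01 * i, wave_aa))
t_steady = (time.perf_counter() - t0) / n_iter
print(f"first call {t_first * 1e3:.2f} ms, then {t_steady * 1e3:.3f} ms/call")

# vmap evaluates all 50 optical depths in one batched call.
taus = jnp.linspace(0.0, 1.0, n_iter)
batched = jax.vmap(toy_photometry, in_axes=(0, None))(taus, wave_aa)
print(batched.shape)  # (50, 10)
```

Passing the wavelength grid with in_axes=None broadcasts it across the batch, so only the parameter axis is vectorized.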
[10]:
# Build a representative model
spec_perf = Parameters(
    mean_sfh_type="tsnorm",
    sfh_tsnorm_log_peak_sfr=Fixed(np.log10(15.0)),
    sfh_tsnorm_peak_lbt_gyr=Fixed(3.0),
    sfh_tsnorm_width_gyr=Fixed(2.5),
    sfh_tsnorm_skew=Fixed(0.2),
    sfh_tsnorm_trunc=Fixed(4.0),
    met_logzsol=Fixed(-0.1),
    dust_model="two_component",
    dust_law_bc="calzetti",
    dust_tau_bc=Fixed(0.5),
    dust_tau_diff=Fixed(0.3),
    dust_slope=Fixed(-0.7),
    dust_emission="dale2014",
    redshift=Fixed(0.05),
    apply_igm=False,
)
model_perf = SEDModel(spec_perf, ssp, observation=observation)
# Base truth dict
truth_perf = {
    "sfh_tsnorm_log_peak_sfr": np.log10(15.0),
    "sfh_tsnorm_peak_lbt_gyr": 3.0,
    "sfh_tsnorm_width_gyr": 2.5,
    "sfh_tsnorm_skew": 0.2,
    "sfh_tsnorm_trunc": 4.0,
    "met_logzsol": -0.1,
    "dust_tau_bc": 0.5,
    "dust_tau_diff": 0.3,
    "dust_slope": -0.7,
    "redshift": 0.05,
}
# Time the first call: it includes the one-off JIT trace and compile
t0 = time.perf_counter()
_ = model_perf.predict_photometry(truth_perf)
t_single = time.perf_counter() - t0
# Time 50 sequential calls (with small parameter variations) that reuse the compiled code
n_iter = 50
t0 = time.perf_counter()
for i in range(n_iter):
    truth_var = truth_perf.copy()
    truth_var["dust_tau_bc"] = 0.5 + 0.01 * np.sin(i / 10.0)
    _ = model_perf.predict_photometry(truth_var)
t_loop = time.perf_counter() - t0
print("\nForward Model Performance")
print(f"First call (incl. JIT compile): {t_single*1000:.2f} ms")
print(f"50 sequential predictions: {t_loop:.3f} s ({t_loop/n_iter*1000:.2f} ms per call)")
print(f"One-off compile cost (approx.): {(t_single - t_loop / n_iter)*1000:.2f} ms")
print()
print("Key lesson: Once compiled, calls are fast (~10-30 ms). Use sequential")
print("loops for sensitivity studies, or vmap() for full batch vectorization.")
/Users/suchethacooray/Projects/tengri/src/tengri/forward/sed_model.py:636: BakedInNebularWarning: BakedInBackend: nebular emission is baked into the SSP file at a FIXED logU and FIXED escape fraction determined when the SSP grid was generated (commonly logU = −3, but depends on the SSP file). The ionization parameter and escape fraction are NOT free parameters — varying neb_logU or neb_fesc in your Parameters will have no effect. Check your SSP file's nebular assumptions. Switch to CloudyGridBackend or CueBackend to vary nebular properties. To suppress: pass ionizing_source_warning='suppress'.
self._nebular_backend = BakedInBackend()
Forward Model Performance
======================================================================
First call (incl. JIT compile): 999.81 ms
50 sequential predictions: 0.909 s (18.18 ms per call)
One-off compile cost (approx.): 981.62 ms
Key lesson: Once compiled, calls are fast (~10-30 ms). Use sequential
loops for sensitivity studies, or vmap() for full batch vectorization.
Where to go next
Beyond what this notebook exercises directly, the API surface a model-builder needs most is:
- Parameters.with_params(**kw) for immutable swaps;
- Parameters.summary() for introspection (its string form, summary_str(), is used above);
- model.predict_sfh(...) and predict_photometry_batch(...) for diagnostics;
- tengri.describe(name) and tengri.cite(...) for the registry and citations.
Natural next steps: 05_fitting_photometry (05_fitting_photometry.py) runs a real fit and reads its posterior, and 06_fitting_spectroscopy (06_fitting_spectroscopy.py) breaks age, dust, and metallicity degeneracies with a spectrum. Stochastic SFHs live behind the sfh_field_psd_* parameters, and AGN components behind agn_disc= and agn_torus=. If you build your own component and register it, the parameter tracking and forward model pick it up unchanged.
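Custom components plug in through a registry. tengri's actual registration API is not shown in this notebook, so the sketch below only illustrates the general decorator-registry pattern. register, _COMPONENTS, parameters_for, and the toy_const component are all hypothetical names.

```python
import numpy as np

# Hypothetical module-level registry: component name -> (function, parameter names).
_COMPONENTS = {}


def register(name, params):
    """Decorator that records a component and the parameters it consumes."""
    def decorator(fn):
        if name in _COMPONENTS:
            raise ValueError(f"component {name!r} is already registered")
        _COMPONENTS[name] = (fn, tuple(params))
        return fn
    return decorator


@register("toy_const", params=("sfh_toy_log_sfr",))
def toy_constant_sfh(t_gyr, sfh_toy_log_sfr):
    """Constant SFR of 10**sfh_toy_log_sfr at every time step."""
    return np.full_like(np.asarray(t_gyr, dtype=float), 10.0 ** sfh_toy_log_sfr)


def parameters_for(name):
    """What a parameter tracker would read to build the free/fixed lists."""
    return _COMPONENTS[name][1]


print(parameters_for("toy_const"))
print(toy_constant_sfh([0.1, 1.0, 10.0], sfh_toy_log_sfr=1.0))
```

Because the parameter names travel with the registration, code that builds free/fixed lists only needs the component's name, which is the property the paragraph above relies on.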