Note
Go to the end to download the full example code.
SFH Recovery Across Four Burstiness Regimes
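Four burstiness regimes — Smooth, Moderate, Bursty and Extreme — are each defined by the amplitude σ and correlation time τ of the power spectral density (PSD) that controls stochastic fluctuations in the star-formation history (SFH). Forward-model SFH draws show the range of histories each regime produces before any inference is run.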
from pathlib import Path
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
# Enable 64-bit (double-precision) arrays in JAX. This is placed before the
# tengri import, presumably so anything tengri builds at import time is
# already float64 — confirm before moving it.
jax.config.update("jax_enable_x64", True)
from tengri import (
Fixed,
Observation,
Parameters,
SEDModel,
Spectroscopy,
Uniform,
load_ssp_data,
setup_style,
)
# Apply tengri's shared plotting style before any figure is created below
# (assumed to set matplotlib defaults — confirm in tengri.setup_style).
setup_style()
def _find_ssp():
name = "ssp_prsc_miles_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5"
for p in [
Path("data") / name,
Path("../data") / name,
Path("../../data") / name,
Path("../../../data") / name,
]:
if p.exists():
return str(p)
return None
# Resolve the SSP template file once; abort the example when it is missing.
SSP_PATH = _find_ssp()
if SSP_PATH is not None:
    ssp = load_ssp_data(SSP_PATH)
else:
    raise FileNotFoundError("SSP data not found — skipping example")
# Spectroscopic wavelength grid: 200 evenly spaced points from 3800 to 9200
# (units presumably Angstrom — confirm against tengri's Spectroscopy docs).
_WAVE_MIN, _WAVE_MAX, _N_WAVE = 3800.0, 9200.0, 200
wave_obs = jnp.linspace(_WAVE_MIN, _WAVE_MAX, _N_WAVE)
_spectroscopy = Spectroscopy(wave_obs=wave_obs)
obs = Observation(spectroscopy=_spectroscopy)

# Parameter specification. Entries are listed in the same order as the
# keyword arguments would be, so Parameters sees an identical call.
_param_kwargs = {
    # tsnorm mean-SFH shape parameters, all held fixed.
    "sfh_tsnorm_log_peak_sfr": Fixed(1.2),
    "sfh_tsnorm_peak_lbt_gyr": Fixed(3.0),
    "sfh_tsnorm_width_gyr": Fixed(3.0),
    "sfh_tsnorm_skew": Fixed(0.3),
    "sfh_tsnorm_trunc": Fixed(2.0),
    # The only free parameters: the SFH field's PSD amplitude and
    # correlation time (Myr), with uniform priors.
    "sfh_field_psd_sigma": Uniform(0.1, 4.0),
    "sfh_field_psd_tau_myr": Uniform(1.0, 300.0),
    # Metallicity, dust and redshift are fixed.
    "met_logzsol": Fixed(-0.3),
    "dust_tau_bc": Fixed(0.3),
    "dust_tau_diff": Fixed(0.2),
    "dust_slope": Fixed(-0.7),
    "redshift": Fixed(0.0),
    # SFH components used by the model: the tsnorm shape and the field.
    "mean_sfh_type": ["tsnorm", "field"],
}
spec = Parameters(**_param_kwargs)
model = SEDModel(spec, ssp, observation=obs)
# Each regime pins the SFH field's PSD amplitude ("sigma") and correlation
# time ("tau", in Myr); "color" is only used for plotting.
REGIMES = [
    {"label": "Smooth", "sigma": 0.3, "tau": 100.0, "color": "#1f77b4"},
    {"label": "Moderate", "sigma": 1.0, "tau": 50.0, "color": "#ff7f0e"},
    {"label": "Bursty", "sigma": 2.0, "tau": 20.0, "color": "#2ca02c"},
    {"label": "Extreme", "sigma": 3.0, "tau": 5.0, "color": "#d62728"},
]

fig, axes = plt.subplots(2, 2, figsize=(10, 6), sharey=False)
axes_flat = axes.flatten()

# Every regime uses the same PRNG key, so the base parameter draw is identical
# across panels and only sigma/tau differ. That draw does not depend on the
# regime, so it is taken once here instead of on every loop iteration.
key = jax.random.PRNGKey(42)
base_params = spec.sample(key)

for ax, reg in zip(axes_flat, REGIMES):
    # Build a fresh dict per regime, overriding the sampled PSD hyperparameters.
    params = {
        **base_params,
        "sfh_field_psd_sigma": jnp.array(reg["sigma"]),
        "sfh_field_psd_tau_myr": jnp.array(reg["tau"]),
    }
    sfh = model.predict_sfh(params)
    t_gyr = np.array(sfh["t_gyr"])
    sfr_full = np.array(sfh["sfr_full"])
    sfr_mean = np.array(sfh["sfr_mean"])

    ax.fill_between(t_gyr, 0, sfr_full, alpha=0.4, color=reg["color"])
    ax.plot(t_gyr, sfr_full, color=reg["color"], lw=1.2, label="SFH draw")
    ax.plot(t_gyr, sfr_mean, color="k", lw=0.8, ls="--", alpha=0.5, label="Mean SFH")
    ax.set_title(f"{reg['label']}: σ={reg['sigma']}, τ={reg['tau']:.0f} Myr")
    ax.set_xlabel("Lookback time [Gyr]")
    ax.set_ylabel(r"SFR [$M_\odot$ yr$^{-1}$]")
    ax.set_xlim(0, 13)
    ax.set_ylim(bottom=0)

# The curves are labelled in every panel, but one legend is enough to
# explain them. Without this call the "Mean SFH" label was never shown.
axes_flat[0].legend(loc="upper right", fontsize=8)

fig.suptitle("Four Burstiness Regimes: IFT PSD Prior Draws", fontsize=12, y=1.02)
fig.tight_layout()
# Save via the figure handle rather than pyplot's implicit "current figure".
fig.savefig("plot_bursty_recovery.png", dpi=150, bbox_inches="tight")
plt.show()