Photometric SED Fit

Generate a mock galaxy with SDSS ugriz photometry and fit it with tengri’s maximum a posteriori (MAP) optimizer. The figure compares the observed, true, and best-fit model photometry, with error bars and normalized residuals.

plot_photometric_fit
from pathlib import Path

import jax
import matplotlib.pyplot as plt
import numpy as np

from tengri import (
    Fitter,
    Fixed,
    Observation,
    Parameters,
    Photometry,
    SEDModel,
    Uniform,
    load_ssp_data,
)
from tengri.analysis.plotting import setup_style

# Apply tengri's plotting style (presumably global matplotlib settings)
# before any figure is created below.
setup_style()


# --- Check for SSP data ---
def _find_ssp():
    """Return the relative path of the SSP HDF5 file, or ``None`` if it is absent.

    Looks for ``data/<file>`` in the current working directory and then in up
    to three parent directories. This lets the example run both from the
    project root and from the sphinx-gallery working directory under ``docs/``.
    The first existing candidate wins.
    """
    filename = "ssp_prsc_miles_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5"
    prefixes = ("", "../", "../../", "../../../")
    candidates = (Path(prefix + "data") / filename for prefix in prefixes)
    return next((str(candidate) for candidate in candidates if candidate.exists()), None)


SSP_PATH = _find_ssp()
if SSP_PATH is None:
    # Fail fast, before any other lookup, and say where the file was expected.
    # Raising aborts the script; how a docs build treats that depends on the
    # sphinx-gallery configuration, so the message does not claim a "skip".
    raise FileNotFoundError(
        "SSP data not found: looked for the SSP HDF5 file under data/ in the "
        "current directory and up to three parent directories"
    )

# Locate the filter-curve cache using the same search order as the SSP file
# (cwd, then up to three parents). Fall back to "data/filters" when none exist.
_FILTER_DIR = next(
    (
        str(candidate)
        for candidate in (
            Path(prefix) / "data" / "filters" for prefix in (".", "..", "../..", "../../..")
        )
        if candidate.exists()
    ),
    "data/filters",
)

# --- Setup ---
ssp_data = load_ssp_data(SSP_PATH)
# One SDSS filter name per ugriz band, resolved via the local filter cache.
bands = [f"sdss_{band}" for band in "ugriz"]
obs = Observation(photometry=Photometry.from_names(bands, cache_dir=_FILTER_DIR))

# The SFH shape, peak SFR and metallicity are free with uniform priors.
# Dust attenuation and redshift are pinned with Fixed(...).
spec = Parameters(
    # star-formation history (tsnorm parameterisation)
    sfh_tsnorm_log_peak_sfr=Uniform(-1.0, 2.5),
    sfh_tsnorm_peak_lbt_gyr=Uniform(0.5, 12.0),
    sfh_tsnorm_width_gyr=Uniform(0.3, 5.0),
    sfh_tsnorm_skew=Uniform(-3.0, 3.0),
    sfh_tsnorm_trunc=Uniform(1.0, 10.0),
    # metallicity
    met_logzsol=Uniform(-2.0, 0.2),
    # held constant during the fit
    dust_tau_bc=Fixed(0.3),
    dust_tau_diff=Fixed(0.2),
    dust_slope=Fixed(-0.7),
    redshift=Fixed(0.05),
)
model = SEDModel(spec, ssp_data, observation=obs)

# --- Generate mock data (star-forming galaxy) ---
# Start from a reproducible prior draw, then pin four of the SFH parameters.
# sfh_tsnorm_trunc and met_logzsol keep their sampled values.
truth_key = jax.random.PRNGKey(42)
true_params = spec.sample(truth_key)
true_params["sfh_tsnorm_peak_lbt_gyr"] = 3.0
true_params["sfh_tsnorm_width_gyr"] = 2.0
true_params["sfh_tsnorm_log_peak_sfr"] = 1.0
true_params["sfh_tsnorm_skew"] = 0.3  # positive skew, intended to favour recent star formation
noise_key = jax.random.PRNGKey(0)
mock = model.mock(true_params, snr=20.0, key=noise_key)

# --- Fit with MAP ---
# 300 Adam steps towards the MAP estimate; the fitter sees only the noisy data.
fitter = Fitter(model, mock.flux_obs, mock.noise)
posterior = fitter.run("map", optimizer="adam", n_steps=300, verbose=False)
# Model photometry evaluated at the fitted parameters.
best_fit = model.predict_photometry(posterior.params)

# --- Plot ---
# SDSS ugriz effective wavelengths in Angstrom; the order must match `bands`.
wave_eff = np.array([3551, 4686, 6166, 7480, 8932])
band_names = ["u", "g", "r", "i", "z"]

# Convert the model outputs (possibly JAX arrays) to NumPy once and reuse them.
flux_obs = np.asarray(mock.flux_obs)
flux_err = np.asarray(mock.noise)
flux_true = np.asarray(mock.flux_true)
flux_model = np.asarray(best_fit)

fig, (ax, ax_res) = plt.subplots(2, 1, figsize=(7, 5), height_ratios=[3, 1], sharex=True)

ax.errorbar(
    wave_eff,
    flux_obs,
    yerr=flux_err,
    fmt="o",
    color="0.3",
    ms=6,
    capsize=3,
    label="Observed",
    zorder=5,
)
ax.scatter(
    wave_eff,
    flux_true,
    marker="s",
    s=50,
    facecolors="none",
    edgecolors="C0",
    lw=1.5,
    label="Truth",
    zorder=4,
)
ax.scatter(wave_eff, flux_model, marker="D", s=40, color="C3", label="MAP fit", zorder=6)
ax.set_ylabel(r"$f_\nu$ [arbitrary]")
ax.legend(frameon=False)
ax.set_title("SDSS Photometric Fit (MAP)")

# Residuals in units of the per-band noise.
residuals = (flux_obs - flux_model) / flux_err
ax_res.axhline(0, color="0.5", ls="--", lw=0.8)
ax_res.scatter(wave_eff, residuals, c="C3", s=30, zorder=5)
ax_res.set_xlabel(r"Wavelength [$\AA$]")
ax_res.set_ylabel(r"$(f_\mathrm{obs} - f_\mathrm{mod}) / \sigma$")
# Keep the usual ±4σ window, but widen it so outliers are not clipped off-axis.
# (With NaN residuals, max() keeps 4.0 because comparisons with NaN are False.)
res_lim = max(4.0, 1.1 * float(np.max(np.abs(residuals))))
ax_res.set_ylim(-res_lim, res_lim)
ax_res.set_xticks(wave_eff)
ax_res.set_xticklabels(band_names)

# tight_layout() recomputes hspace, so re-apply the narrow gap afterwards to
# keep the residual panel attached to the main panel.
fig.tight_layout()
fig.subplots_adjust(hspace=0.05)
fig.savefig("plot_photometric_fit.png", dpi=150, bbox_inches="tight")
plt.show()

Gallery generated by Sphinx-Gallery