Note
Go to the end to download the full example code.
Prior vs Posterior: Parameter Constraints from Inference
A Bayesian fit starts from prior distributions over the model parameters and updates them with the observed data to obtain posterior distributions. This script compares the priors (dashed lines) and posteriors (histograms) of key physical parameters (the age of peak star formation, metallicity, and diffuse dust optical depth) after fitting mock photometry. The true input value of each parameter is marked with a solid red line.
from pathlib import Path
import jax
import matplotlib
# Non-interactive Agg backend so the script runs without a display (e.g. during
# a docs build); selected before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
# Enable 64-bit floating point in JAX (32-bit is the default). Set here, before
# the tengri import below and before any JAX array is created in this script.
jax.config.update("jax_enable_x64", True)
from tengri import (
    Fitter,
    Fixed,
    Observation,
    Parameters,
    Photometry,
    SEDModel,
    Uniform,
    load_ssp_data,
    setup_style,
)
# Apply the project's plotting style (presumably matplotlib rcParams — confirm
# in tengri) to every figure created below.
setup_style()
def _find_ssp(name="ssp_prsc_miles_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5"):
    """Locate the SSP data file from the project root or docs/ (sphinx-gallery) cwd.

    Looks for ``data/<name>`` in the current working directory and in up to
    three parent directories (``../data``, ``../../data``, ``../../../data``).
    If none of those exist and this script's ``__file__`` is known, the same
    relative locations are tried next to the script, so the example also runs
    when launched from an unrelated working directory.

    Parameters
    ----------
    name : str
        File name of the SSP HDF5 table to look for.

    Returns
    -------
    str or None
        Path of the first existing candidate (relative to its search base),
        or ``None`` if no candidate exists.
    """
    bases = [Path(".")]
    script_file = globals().get("__file__")
    if script_file:
        # Fallback only: cwd-relative candidates are always tried first, so
        # results for the original search locations are unchanged.
        bases.append(Path(script_file).resolve().parent)
    for base in bases:
        for depth in range(4):
            # depth 0 -> data/<name>, depth 1 -> ../data/<name>, ...
            candidate = base.joinpath(*([".."] * depth), "data", name)
            if candidate.exists():
                return str(candidate)
    return None
SSP_PATH = _find_ssp()
if SSP_PATH is None:
    # The example cannot run without the SSP table, so abort with a message
    # that says what is missing and where it was expected.
    raise FileNotFoundError(
        "SSP data file not found in any searched data/ directory; "
        "place it in the project's data/ directory to run this example"
    )
ssp = load_ssp_data(SSP_PATH)
# Mock photometry is produced in the five SDSS bands named below.
obs = Observation(
    photometry=Photometry.from_names(["sdss_u", "sdss_g", "sdss_r", "sdss_i", "sdss_z"])
)
# --- Priors ---
# Free parameters, each with a Uniform prior. Three of them are later drawn as
# the dashed "prior" curves in the comparison figure.
_free_priors = {
    "sfh_tsnorm_log_peak_sfr": Uniform(-1.0, 2.5),
    "sfh_tsnorm_peak_lbt_gyr": Uniform(0.5, 12.0),
    "sfh_tsnorm_width_gyr": Uniform(0.3, 5.0),
    "met_logzsol": Uniform(-2.0, 0.2),
    "dust_tau_diff": Uniform(0.0, 1.5),
}
# Held-fixed parameters and the SFH family complete the specification; the
# keyword order matches the free priors first, then the fixed settings.
spec = Parameters(
    **_free_priors,
    dust_slope=Fixed(-0.7),
    redshift=Fixed(0.1),
    mean_sfh_type="tsnorm",
)
model = SEDModel(spec, ssp, observation=obs)
# --- Generate mock data with known truth ---
key = jax.random.PRNGKey(99)
# Derive independent subkeys for drawing the truth and for the mock noise:
# reusing one JAX PRNG key for two random draws makes them correlated.
key_truth, key_noise = jax.random.split(key)
true_params = spec.sample(key_truth)
# Pin the free parameters to an interesting star-forming galaxy. Every value
# lies inside the corresponding Uniform prior of ``spec``.
true_params["sfh_tsnorm_peak_lbt_gyr"] = 3.0
true_params["sfh_tsnorm_width_gyr"] = 2.0
true_params["sfh_tsnorm_log_peak_sfr"] = 1.0
true_params["met_logzsol"] = -0.2
true_params["dust_tau_diff"] = 0.5
mock = model.mock(true_params, snr=20.0, key=key_noise)
# --- Fit ---
# Fit the mock observation: ``mock.flux_obs`` are the noisy fluxes and
# ``mock.noise`` presumably their per-band uncertainties (confirm in tengri).
fitter = Fitter(model, mock.flux_obs, mock.noise, data_type="photometry")
# Short MAP optimisation first. Its return value is discarded here, so the
# fitter presumably keeps the result internally (e.g. to initialise VI) — TODO
# confirm. The order MAP -> compile -> VI is kept as written.
fitter.run("map", n_steps=200, verbose=False)
# NOTE(review): presumably prepares/JIT-compiles the inference step before VI.
fitter.compile(verbose=False)
# Variational inference. The very small n_iterations / n_samples look chosen to
# keep the example fast, so the posterior may be only roughly converged.
# n_posterior_samples presumably sets how many draws end up in ``.samples``.
posterior = fitter.run(
    "vi",
    n_iterations=8,
    n_samples=3,
    n_posterior_samples=1000,
    verbose=False,
)
# --- Extract posterior samples and choose what to compare ---
# Presumably a dict-like mapping: parameter name -> posterior draws.
posterior_samples = posterior.samples
# Parameters to compare. The three lists below are zipped in lockstep by the
# plotting loop, so keep them the same length and in the same order.
selected_params = [
    "sfh_tsnorm_peak_lbt_gyr",
    "met_logzsol",
    "dust_tau_diff",
]
param_labels = [
    r"Age of peak SFR [Gyr]",
    r"log(Z/Z$_\odot$)",
    r"$\tau_{\rm diff}$",
]
# Support of each Uniform prior. These are copied by hand from the Uniform(...)
# bounds in ``spec`` (not read from it), so update both places together.
prior_ranges = [
    (0.5, 12.0),
    (-2.0, 0.2),
    (0.0, 1.5),
]
fig, axes = plt.subplots(1, 3, figsize=(13, 4))
for ax, param_name, param_label, (prior_min, prior_max) in zip(
    axes, selected_params, param_labels, prior_ranges
):
    # The uniform prior is known without the fit: constant density
    # 1 / (max - min) on its support. It is drawn as a dashed box so the edges
    # of the support are visible against the histogram.
    prior_density = 1.0 / (prior_max - prior_min)
    ax.plot(
        [prior_min, prior_min, prior_max, prior_max],
        [0.0, prior_density, prior_density, 0.0],
        color="gray",
        lw=2.0,
        ls="--",
        label="Prior (uniform)",
    )
    if param_name in posterior_samples:
        # Flatten so ax.hist sees one dataset even if the draws carry extra
        # axes; a 2-D array would be split into one histogram per column.
        samples = np.ravel(np.asarray(posterior_samples[param_name]))
        ax.hist(
            samples,
            bins=30,
            color="C0",
            alpha=0.6,
            density=True,
            label="Posterior",
            edgecolor="C0",
            linewidth=1.0,
        )
    else:
        # Say so explicitly instead of leaving a histogram-less panel unexplained.
        ax.text(
            0.5,
            0.5,
            "no posterior samples",
            transform=ax.transAxes,
            ha="center",
            va="center",
        )
    # Mark the truth value used to generate the mock data.
    if param_name in true_params:
        truth_val = float(true_params[param_name])
        ax.axvline(truth_val, color="red", lw=2.0, ls="-", alpha=0.7, label="Truth")
    ax.set_xlabel(param_label, fontsize=11)
    ax.set_ylabel("Probability density", fontsize=10)
    ax.set_ylim(bottom=0.0)
    # The prior line is always present, so the legend is never empty.
    ax.legend(fontsize=9, frameon=False)
    ax.grid(True, alpha=0.3, axis="y")
fig.suptitle(
    "Prior vs Posterior: Parameter Constraints from Mock Photometry Fit",
    fontsize=13,
)
fig.tight_layout()
# Save next to this script when ``__file__`` is defined; otherwise (e.g. an
# interactive session) fall back to the current working directory.
_script_file = globals().get("__file__")
script_dir = Path(_script_file).resolve().parent if _script_file else Path(".")
# Save and close this specific figure rather than pyplot's implicit "current"
# figure, so the output cannot depend on which figure happens to be active.
fig.savefig(str(script_dir / "plot_prior_posterior_compare.png"), dpi=150, bbox_inches="tight")
plt.close(fig)