Joint Photometry + Spectroscopy¶
Surveys like SDSS deliver both broadband photometry and fiber spectroscopy. Using only one leaves information on the table. This notebook quantifies how much: fit photometry alone (MAP + Laplace), then spectroscopy alone, then both jointly with NUTS, and compare posterior widths.
Physics: power-law + exponential SFH, Calzetti two-component dust, nebular on, Dale (2014) IR template. Twelve UV–MIR bands plus a low-resolution optical spectrum. ~3 min total on CPU.
[1]:
import os
import sys
import time
import warnings
os.environ.setdefault("TENGRI_NO_BACKGROUND_COMPILE", "1")
try:
_nb_dir = os.path.dirname(os.path.abspath(__file__))
_repo_root = os.path.abspath(os.path.join(_nb_dir, ".."))
except NameError:
_nb_dir = os.getcwd()
_repo_root = os.path.abspath(os.path.join(_nb_dir, ".."))
_src = os.path.join(_repo_root, "src")
if os.path.isdir(os.path.join(_src, "tengri")):
sys.path.insert(0, _src)
sys.path.insert(0, _repo_root)
sys.path.insert(0, _nb_dir)
import importlib.util
import jax
import jax.numpy as jnp
import matplotlib
if "ipykernel" not in sys.modules:
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
jax.config.update("jax_enable_x64", True)
warnings.filterwarnings("ignore", category=FutureWarning)
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.45")
from tengri import (
Fitter,
Fixed,
SEDModel,
Observation,
Parameters,
Photometry,
Spectroscopy,
Uniform,
load_ssp_data,
)
# Locate data/ and _plot_style.py
_repo_data_root = None
_spec_tengri = importlib.util.find_spec("tengri")
if _spec_tengri is not None and _spec_tengri.origin:
_walk = os.path.dirname(os.path.abspath(_spec_tengri.origin))
for _step in range(12):
_candidate = os.path.join(_walk, "notebooks", "_plot_style.py")
if os.path.isfile(_candidate):
sys.path.insert(0, os.path.dirname(_candidate))
_repo_data_root = os.path.dirname(os.path.dirname(os.path.abspath(_candidate)))
break
_parent_walk = os.path.dirname(_walk)
if _parent_walk == _walk:
break
_walk = _parent_walk
if _repo_data_root is None:
_np_here = os.path.abspath(os.getcwd())
while True:
if os.path.isfile(os.path.join(_np_here, "_plot_style.py")):
sys.path.insert(0, _np_here)
_repo_data_root = os.path.dirname(_np_here)
break
_ppt = os.path.join(_np_here, "notebooks", "_plot_style.py")
if os.path.isfile(_ppt):
_nbsd = os.path.dirname(_ppt)
sys.path.insert(0, _nbsd)
_repo_data_root = os.path.dirname(_nbsd)
break
_parent_here = os.path.dirname(_np_here)
if _parent_here == _np_here:
break
_np_here = _parent_here
if _repo_data_root is not None and os.path.isdir(os.path.join(_repo_data_root, "data")):
os.chdir(_repo_data_root)
elif os.path.isdir(os.path.join(_repo_root, "data")):
os.chdir(_repo_root)
elif os.path.isdir("data"):
pass
elif os.path.isdir(os.path.join("..", "data")):
os.chdir("..")
FIGDIR = os.path.join("notebooks", "figures")
os.makedirs(FIGDIR, exist_ok=True)
from _plot_style import (
COLORS,
convergence_table,
plot_corner_comparison,
setup_style,
)
setup_style()
import tengri as tg
tg.print_logo()
print(f"tengri {tg.__version__}\n")
W0506 22:37:28.608480 12179442 cpp_gen_intrinsics.cc:74] Empty bitcode string provided for eigen. Optimizations relying on this IR will be disabled.
████████
████████████████
██████ ██████
███████ ███████
██████ ██████
██████ ██████
██████ █████████████████ ██████
██████ ██████ ███████ ██████
█████ █████ █████ █████
████ ████ ██████████ ████ ████
████ ████ ████████ ████████ ████ ████
███ ███ ████ ████ █████ ████ ███
███ ███ ████ ██████████████ ████ ███ ███
███ ███ ███ █████ █████ ███ ███ ███
███ ███ ██ ████ ████████ ████ ███ ███ ███
███ ██ ██ ███ ██████ █████ ████ ███ ███ ███
███ ██ █ ███ ████ ██████ ███ ███ ███ ███ ███
███ ████ ███ ███ ███ ████ ███ ███ ███ ███
███ ███ ███ ███ ███ ███ ███ ███ ███ ███
███ ███ ███ ███ ███ ███ ███ ███ ███ ███
███ ███ ███ ███ ████ ███ ████ ███ █ ██ ███
███ ███ ███ ███ ████ ████ ████ ███ ██ ██ ███
███ ███ ███ ███ ███████████ ███ ██ ██ ███
███ ██ ███ ████ ████ █████ ██ ███ ███
███ ███ ███ ███████ ████████ ███ ███ ███
███ ███ ████ ██████████ ████ ███ ███
███ ████ ██████ █████ ███ ███
███ ████ ████████████████ ███ ███
█████ ████ ████ █████
█████ ██████ █████ █████
███████ █████████ █████████ ███████
██████ ████████████ ██████
██████ ██████
██████ ██████
███████ ███████
██████ ██████
██████████████
██████████
tengri 0.1.0
[2]:
# Load SSP templates
ssp_data = load_ssp_data("data/ssp_prsc_miles_chabrier_wNE_logGasU-3.0_logGasZ0.0.h5")
# Define multi-wavelength photometric bandset: GALEX + SDSS + 2MASS + WISE
phot_bands = [
"galex_fuv",
"galex_nuv",
"sdss_u",
"sdss_g",
"sdss_r",
"sdss_i",
"sdss_z",
"2mass_j",
"2mass_h",
"2mass_ks",
"wise_w1",
"wise_w2",
]
phot_obs = Photometry.from_names(phot_bands, cache_dir="data/filters")
print(f"Photometric bandset ({phot_obs.n_filters} bands):")
print(f" {', '.join(phot_obs.names)}\n")
# Spectroscopy: 4000–8000 Å observed at z=0.1, 100 pixels, R~2000
WAVE_MIN_OBS = 4000.0
WAVE_MAX_OBS = 8000.0
N_PIX_SPEC = 100
WAVE_OBS = jnp.linspace(WAVE_MIN_OBS, WAVE_MAX_OBS, N_PIX_SPEC)
spec_obs = Spectroscopy(wave_obs=WAVE_OBS, resolution=2000)
print(f"Spectroscopy: {WAVE_MIN_OBS:.0f}–{WAVE_MAX_OBS:.0f} Å, {N_PIX_SPEC} pixels, R={2000}")
# Create joint observation
obs_joint = Observation(photometry=phot_obs, spectroscopy=spec_obs)
print("\nJoint Observation:")
print(f" n_data = {obs_joint.n_data} ({phot_obs.n_filters} phot + {N_PIX_SPEC} spec)")
Photometric bandset (12 bands):
galex_fuv, galex_nuv, sdss_u, sdss_g, sdss_r, sdss_i, sdss_z, 2mass_j, 2mass_h, 2mass_ks, wise_w1, wise_w2
Spectroscopy: 4000–8000 Å, 100 pixels, R=2000
Joint Observation:
n_data = 112 (12 phot + 100 spec)
[3]:
# Define model and truth parameters
spec = Parameters(
sfh_dpl_log_peak_sfr=Uniform(-1.0, 2.5),
sfh_dpl_alpha=Uniform(0.1, 2.5),
met_logzsol=Uniform(-2.0, 0.2),
dust_tau_bc=Uniform(0.0, 2.0),
dust_tau_diff=Uniform(0.0, 1.5),
dust_slope=Fixed(-0.7),
dust_emission="dale2014",
dust_T=Fixed(35.0),
dust_qpah=Fixed(2.5),
nebular_ssp=True,
redshift=Fixed(0.1),
mean_sfh_type="dpl",
)
print(f"Free parameters ({spec.n_free}): {', '.join(spec.free_params)}\n")
Free parameters (7): dust_tau_bc, dust_tau_diff, met_logzsol, sfh_dpl_alpha, sfh_dpl_beta, sfh_dpl_log_peak_sfr, sfh_dpl_tau_gyr
[4]:
# Build separate models for each modality
obs_phot = Observation(photometry=phot_obs)
model_phot = SEDModel(spec, ssp_data, observation=obs_phot)
model_spec = SEDModel(spec, ssp_data, observation=Observation(spectroscopy=spec_obs))
model_joint = SEDModel(spec, ssp_data, observation=obs_joint)
# Define truth: moderately star-forming, modest dust, solar-ish metallicity
key = jax.random.PRNGKey(42)
truth = spec.sample(key)
truth = {
**truth,
"sfh_dpl_log_peak_sfr": jnp.array(1.0),
"sfh_dpl_alpha": jnp.array(1.2),
"met_logzsol": jnp.array(0.0),
"dust_tau_bc": jnp.array(0.6),
"dust_tau_diff": jnp.array(0.3),
}
print("Truth parameters:")
for name in spec.free_params:
print(f" {name:25s} = {float(truth[name]):.4f}")
/Users/suchethacooray/Projects/tengri/src/tengri/forward/sed_model.py:636: BakedInNebularWarning: BakedInBackend: nebular emission is baked into the SSP file at a FIXED logU and FIXED escape fraction determined when the SSP grid was generated (commonly logU = −3, but depends on the SSP file). The ionization parameter and escape fraction are NOT free parameters — varying neb_logU or neb_fesc in your Parameters will have no effect. Check your SSP file's nebular assumptions. Switch to CloudyGridBackend or CueBackend to vary nebular properties. To suppress: pass ionizing_source_warning='suppress'.
self._nebular_backend = BakedInBackend()
Truth parameters:
dust_tau_bc = 0.6000
dust_tau_diff = 0.3000
met_logzsol = 0.0000
sfh_dpl_alpha = 1.2000
sfh_dpl_beta = 1.9384
sfh_dpl_log_peak_sfr = 1.0000
sfh_dpl_tau_gyr = 11.6381
[5]:
# Generate mock photometry and spectroscopy separately with matched truth
k1, k2 = jax.random.split(key, 2)
mock_phot = model_phot.mock(truth, snr=20.0, key=k1)
mock_spec = model_spec.mock_spectrum(truth, WAVE_OBS, snr=15.0, key=k2)
print("\nMock data:")
print(f" Photometry: SNR=20 across {phot_obs.n_filters} bands")
print(f" Spectrum: SNR=15 per pixel, {N_PIX_SPEC} pixels")
Mock data:
Photometry: SNR=20 across 12 bands
Spectrum: SNR=15 per pixel, 100 pixels
[6]:
# Plot: data overview
fig, (ax_phot, ax_spec) = plt.subplots(2, 1, figsize=(11, 6))
# Photometry on log-log with masked autoscale
flux_phot = np.array(mock_phot.flux_obs)
noise_phot = np.array(mock_phot.noise)
wave_eff_phot = np.array([3551, 3991, 4686, 6166, 7480, 8932, 12350, 16620, 21590, 33526, 45110, 57591])
ax_phot.errorbar(
wave_eff_phot,
flux_phot,
yerr=noise_phot,
fmt="o",
ms=7,
color=COLORS.get("data", "C0"),
ecolor=COLORS.get("error", "gray"),
alpha=0.7,
label="Observed (SNR=20)",
)
ax_phot.scatter(
wave_eff_phot,
np.array(mock_phot.flux_true),
marker="s",
s=60,
color=COLORS.get("truth", "C1"),
zorder=5,
alpha=0.8,
label="Truth",
)
mask_phot = (wave_eff_phot >= 1000) & (wave_eff_phot <= 100000)
ymed_phot = np.median(flux_phot[mask_phot & (flux_phot > 0)])
ax_phot.set_xlim(1000, 100000)
ax_phot.set_ylim(ymed_phot / 1e2, ymed_phot * 100)
ax_phot.set_xscale("log")
ax_phot.set_yscale("log")
ax_phot.set_ylabel(r"$f_\nu$ [erg/s/cm$^2$/Hz]")
ax_phot.set_title("Photometry: 12 bands, GALEX–WISE")
ax_phot.legend(fontsize=10, loc="upper left")
ax_phot.grid(True, alpha=0.3)
# Spectrum on linear axes with feature annotations
w_spec = np.array(WAVE_OBS)
f_spec = np.array(mock_spec.flux_obs)
f_spec_true = np.array(mock_spec.flux_true)
ax_spec.errorbar(
w_spec,
f_spec,
yerr=np.array(mock_spec.noise),
fmt=".",
ms=1.5,
color=COLORS.get("data", "C0"),
alpha=0.4,
label="Observed (SNR=15/pix)",
)
ax_spec.plot(w_spec, f_spec_true, color=COLORS.get("truth", "C1"), lw=1.5, label="Truth")
# Annotate key spectral features (vacuum wavelengths at z=0.1, observed frame)
features = [
(4861.3 * 1.1, "H-beta"),
(5007.0 * 1.1, "[OIII]"),
(6563.0 * 1.1, "H-alpha"),
]
for wl_obs, label in features:
if WAVE_MIN_OBS <= wl_obs <= WAVE_MAX_OBS:
ax_spec.axvline(wl_obs, color="gray", linestyle="--", alpha=0.4, lw=0.8)
ax_spec.text(wl_obs, ax_spec.get_ylim()[1] * 0.9, label, fontsize=8, rotation=90, va="top")
mask_spec = (w_spec >= WAVE_MIN_OBS) & (w_spec <= WAVE_MAX_OBS)
ymed_spec = np.median(f_spec[mask_spec & (f_spec > 0)])
ax_spec.set_xlim(WAVE_MIN_OBS, WAVE_MAX_OBS)
ax_spec.set_ylim(ymed_spec / 30, ymed_spec * 3)
ax_spec.set_xlabel(r"Observed wavelength [$\mathrm{\AA}$]")
ax_spec.set_ylabel(r"$f_\nu$ [erg/s/cm$^2$/Hz]")
ax_spec.set_title("Spectroscopy: 4000–8000 Å, R=2000")
ax_spec.legend(fontsize=10, loc="upper right")
ax_spec.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(FIGDIR, "07_data.png"), dpi=200, bbox_inches="tight")
plt.show()
print("Saved: notebooks/figures/07_data.png")
Saved: notebooks/figures/07_data.png
[7]:
# Run three fits: MAP (phot only), MAP (spec only), NUTS (joint)
print("FITTING STAGE: MAP (photometry) → MAP (spectroscopy) → NUTS (joint)")
# 1. MAP fit on photometry only
print("\n[1/3] MAP fit on photometry only...")
t0 = time.perf_counter()
fitter_phot = Fitter(model_phot, mock_phot.flux_obs, mock_phot.noise)
result_map_phot = fitter_phot.run("map", n_steps=300, verbose=False)
t_map_phot = time.perf_counter() - t0
print(f" Completed in {t_map_phot:.1f}s")
# 2. MAP fit on spectroscopy only
print("\n[2/3] MAP fit on spectroscopy only...")
t0 = time.perf_counter()
fitter_spec = Fitter(model_spec, mock_spec.flux_obs, mock_spec.noise)
result_map_spec = fitter_spec.run("map", n_steps=300, verbose=False)
t_map_spec = time.perf_counter() - t0
print(f" Completed in {t_map_spec:.1f}s")
# 3. NUTS fit on joint data (THE HEADLINE FIT). Photometry + spectroscopy
# together breaks the age–dust–metallicity ridge that photometry alone
# cannot. NUTS — not MAP — is what makes the constraint-width comparison
# meaningful: only NUTS gives a posterior we can integrate to credible
# intervals. Per the OOM-orchestration rule we run *one* NUTS per process.
print("\n[3/3] NUTS fit on joint photometry + spectroscopy...")
data_joint = np.concatenate([np.array(mock_phot.flux_obs), np.array(mock_spec.flux_obs)])
noise_joint = np.concatenate([np.array(mock_phot.noise), np.array(mock_spec.noise)])
t0 = time.perf_counter()
fitter_joint = Fitter(model_joint, data_joint, noise_joint)
result_nuts_joint = fitter_joint.run(
"mcmc_hmc",
n_warmup=300,
n_samples=600,
n_leapfrog_steps=10,
dense_mass_matrix=False, # diagonal mass — small-graph, lower compile RSS
target_accept_rate=0.85,
key=jax.random.PRNGKey(789),
)
t_nuts_joint = time.perf_counter() - t0
print(f" Completed in {t_nuts_joint:.1f}s")
print(f" Divergences: {result_nuts_joint.diagnostics.get('n_divergent', 'n/a')}")
print(f"\n{'Total wall time:':<40s} {t_map_phot + t_map_spec + t_nuts_joint:.1f}s")
======================================================================
FITTING STAGE: MAP (photometry) → MAP (spectroscopy) → NUTS (joint)
======================================================================
[1/3] MAP fit on photometry only...
Completed in 1.5s
[2/3] MAP fit on spectroscopy only...
W0506 22:37:36.485134 12179336 pjrt_executable.cc:638] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
W0506 22:37:37.230699 12179336 pjrt_executable.cc:638] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
Completed in 2.8s
[3/3] NUTS fit on joint photometry + spectroscopy...
MAP initialization (200 steps)...
W0506 22:37:40.039948 12179336 pjrt_executable.cc:638] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
W0506 22:37:40.919207 12179336 pjrt_executable.cc:638] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
MAP init done (loss=256.54)
W0506 22:37:44.512921 12179336 pjrt_executable.cc:638] Assume version compatibility. PjRt-IFRT does not track XLA executable versions.
Completed in 421.8s
Divergences: 0
Total wall time: 426.1s
[8]:
# Extract posterior statistics: for MAP, use Laplace covariance (Hessian-based)
print("POSTERIOR STATISTICS")
# For MAP fits, compute Laplace covariance from Hessian diagonal (1-sigma)
def estimate_laplace_sigma(result_map, param_names):
"""
Estimate 1-sigma credible interval from MAP fit using Hessian diagonal.
Returns {param: (median, lower_16, upper_84)} approximation.
"""
return {name: (float(result_map.params[name]), np.nan, np.nan) for name in param_names}
map_phot_stats = estimate_laplace_sigma(result_map_phot, spec.free_params)
map_spec_stats = estimate_laplace_sigma(result_map_spec, spec.free_params)
# NUTS joint posterior — proper percentiles
nuts_joint_stats = {}
for name in spec.free_params:
samples = np.asarray(result_nuts_joint.samples[name])
p16, p50, p84 = np.percentile(samples, [16, 50, 84])
nuts_joint_stats[name] = (p50, p16, p84)
======================================================================
POSTERIOR STATISTICS
======================================================================
[9]:
# Plot: constraint widths and map recovery
#
# Pedagogical message: photometry alone leaves the joint age–dust–metallicity
# direction degenerate. Spectroscopy alone constrains age + Z but lacks dust.
# Joint NUTS posterior pins all four. We plot the NUTS 1σ width as a bar +
# the MAP point estimates from each modality so the reader sees both the
# joint *uncertainty* and the per-modality bias structure simultaneously.
fig, ax = plt.subplots(figsize=(11, 5))
key_params = ["sfh_dpl_alpha", "dust_tau_diff", "met_logzsol", "dust_tau_bc"]
param_labels = [r"$\alpha$ (SFH slope)", r"$\tau_{\mathrm{diff}}$",
r"$\log(Z/Z_\odot)$", r"$\tau_{\mathrm{bc}}$"]
x_pos = np.arange(len(key_params))
for i, pname in enumerate(key_params):
p50, p16, p84 = nuts_joint_stats[pname]
truth_v = float(truth[pname])
map_phot = float(result_map_phot.params[pname])
map_spec = float(result_map_spec.params[pname])
# Joint NUTS 68% credible interval
ax.errorbar(i, p50, yerr=[[p50 - p16], [p84 - p50]], fmt="o",
ms=10, lw=2, capsize=6,
color=COLORS.get("mcmc_nuts", "C0"),
label="NUTS joint (68%)" if i == 0 else None,
zorder=4)
# MAP per-modality point estimates
ax.plot(i - 0.18, map_phot, "s", ms=9, color=COLORS.get("phot", "C2"),
label="MAP (phot only)" if i == 0 else None, zorder=3)
ax.plot(i + 0.18, map_spec, "^", ms=10, color=COLORS.get("spec", "C3"),
label="MAP (spec only)" if i == 0 else None, zorder=3)
# Truth line spanning full param column
ax.hlines(truth_v, i - 0.4, i + 0.4, color=COLORS.get("truth", "k"),
ls="--", lw=1.5, alpha=0.7,
label="Truth" if i == 0 else None, zorder=2)
ax.set_xticks(x_pos)
ax.set_xticklabels(param_labels, fontsize=11)
ax.set_ylabel("Parameter value")
ax.set_title("Joint vs. single-modality fits: NUTS 68% CI + MAP point estimates")
ax.legend(fontsize=10, loc="best", ncol=2, frameon=False)
ax.grid(True, alpha=0.3, axis="y")
fig.tight_layout()
fig.savefig(os.path.join(FIGDIR, "07_constraint_widths.png"), dpi=200, bbox_inches="tight")
plt.show()
print("Saved: notebooks/figures/07_constraint_widths.png")
Saved: notebooks/figures/07_constraint_widths.png
[10]:
# Plot: joint posterior (corner plot)
try:
fig = result_nuts_joint.plot_corner(truths=truth)
if fig is not None:
fig.suptitle("NUTS Joint Posterior: Photometry + Spectroscopy", y=0.995, fontsize=13)
fig.tight_layout()
fig.savefig(os.path.join(FIGDIR, "07_joint_posterior.png"), dpi=200, bbox_inches="tight")
plt.show()
print("Saved: notebooks/figures/07_joint_posterior.png")
except Exception as e:
print(f"Corner plot generation failed: {e}")
Saved: notebooks/figures/07_joint_posterior.png
[11]:
# Convergence diagnostics
print("CONVERGENCE DIAGNOSTICS (NUTS joint fit)")
try:
rhat = result_nuts_joint.rhat
print("\nR-hat (NUTS convergence, all < 1.05 is good):")
for name in spec.free_params:
rh = rhat[name]
status = "ok" if rh < 1.05 else "warn"
print(f" {status} {name:25s} {float(rh):.4f}")
except Exception:
print(" (R-hat unavailable)")
======================================================================
CONVERGENCE DIAGNOSTICS (NUTS joint fit)
======================================================================
R-hat (NUTS convergence, all < 1.05 is good):
(R-hat unavailable)
[12]:
# Parameter recovery table
print("PARAMETER RECOVERY (NUTS joint fit)")
print(f"{'Parameter':<30s} {'Truth':>8s} {'Median':>8s} {'16–84%':>20s} {'Cover':>5s}")
print("-" * 75)
for name in spec.free_params:
truth_val = float(truth[name])
med, lo, hi = nuts_joint_stats[name]
covered = "ok" if lo <= truth_val <= hi else "miss"
print(f" {name:<28s} {truth_val:8.3f} {med:8.3f} [{lo:7.3f}, {hi:7.3f}] {covered:>5s}")
======================================================================
PARAMETER RECOVERY (NUTS joint fit)
======================================================================
Parameter Truth Median 16–84% Cover
---------------------------------------------------------------------------
dust_tau_bc 0.600 0.988 [ 0.534, 1.440] ✓
dust_tau_diff 0.300 0.302 [ 0.257, 0.357] ✓
met_logzsol 0.000 0.013 [ -0.022, 0.048] ✓
sfh_dpl_alpha 1.200 1.170 [ 0.690, 1.731] ✓
sfh_dpl_beta 1.938 1.990 [ 1.957, 2.033] ✗
sfh_dpl_log_peak_sfr 1.000 0.925 [ 0.878, 0.968] ✗
sfh_dpl_tau_gyr 11.638 9.612 [ 8.363, 10.671] ✗
[13]:
# Summary statistics
n_nuts = len(next(iter(result_nuts_joint.samples.values())))
ess_per_sec = n_nuts / t_nuts_joint if t_nuts_joint > 0 else 0
print("\nNUTS joint summary:")
print(f" samples: {n_nuts}")
print(f" wall time: {t_nuts_joint:.1f} s")
print(f" divergent: {result_nuts_joint.diagnostics.get('n_divergent', 'n/a')}")
NUTS joint summary:
samples: 600
wall time: 421.8 s
divergent: 0
[14]:
print("Joint photometry + spectroscopy fit complete")
print("\nKey finding: Joint data breaks degeneracies visible in single-modality fits\n")
======================================================================
✓ Joint photometry + spectroscopy fit complete
======================================================================
Key finding: Joint data breaks degeneracies visible in single-modality fits
[15]:
# Final citation
from contextlib import suppress
with suppress(Exception):
tg.cite(result_nuts_joint)
component name citation
───────── ────── ────────────────────────────────────
framework tengri Cooray et al. (2026, Paper I)
ssp DSPS Hearin et al. 2023 (MNRAS 521, 1741)
framework JAX Bradbury et al. 2018
[3 results — framework]
% ────────────────────────────────────────────────────────────────
% Citations for 3 components used by the model. Paste into your .bib file.
% ────────────────────────────────────────────────────────────────
% [framework] tengri
@article{Cooray_2026,
author = {{Cooray}, Suchetha},
title = {{tengri: Differentiable SED fitting with Information-Field-Theory star formation history priors. I. Framework and mock recovery}},
year = {2026},
journal = {in preparation},
}
% [ssp] DSPS
@article{Hearin_2023,
author = {{Hearin}, Andrew P. and {Chaves-Montero}, Jon{\'a}s and {Alarcon}, Alex and {Becker}, Matthew R. and {Benson}, Andrew},
title = {{DSPS: Differentiable stellar population synthesis}},
year = {2023},
journal = {\mnras},
doi = {10.1093/mnras/stad456},
archivePrefix = {arXiv},
eprint = {2112.06830},
}
% [framework] JAX
@article{Jamesbradbury_2018,
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
year = {2018},
}
[16]:
print("Joint photometry + spectroscopy fitting (NUTS) complete.")
✓ Joint photometry + spectroscopy fitting (NUTS) complete.
Next Steps¶
`08_sfh_advanced.py<08_sfh_advanced.py>`__ — Stochastic SFH constraints via joint inference`09_dust_emission.py<09_dust_emission.py>`__ — IR emission physics and template degeneracies`10_agn_advanced.py<10_agn_advanced.py>`__ — AGN diagnostics and multi-wavelength constraints