"""Green's functions and window functions for SFH sensitivity.
The Green's function G_lambda(t_age) tells you how much a stellar
population of age t_age contributes to the flux at wavelength lambda
(or through filter lambda). This is the SSP light-to-mass ratio (luminosity per unit stellar mass, the inverse of the mass-to-light ratio).
The window function W_lambda(t_age) = G_lambda(t_age) * <SFR(t_age)>
weights the Green's function by the mean SFH, telling you which
lookback times actually contribute to the observed flux.
Together these answer:
- "Which ages does H-alpha probe?" (young, ~few Myr)
- "Which ages does the Balmer break probe?" (intermediate, ~100 Myr-1 Gyr)
- "Which ages does the K-band probe?" (old, ~Gyr)
- "At what timescales can I constrain PSD power?"
This is the time-domain complement to the gradient SEDs in saliency.py.
Connects to Munoz+2026 Eq. 11 and Iyer+2024 window function formalism.
Key insight: the Fourier transform of W gives the frequency-domain
sensitivity, directly linking observable bands to PSD timescales.
"""
import jax
import jax.numpy as jnp
[docs]
def compute_green_function(
    ssp_flux_at_z, ssp_wave, filter_wave=None, filter_trans=None, wave_target=None
):
    """Compute Green's function G(t_age) for a filter or wavelength.

    G(t_age) = flux contribution per unit stellar mass at age t_age.

    For photometry:
        G = int L_SSP(lambda|t_age) * T(lambda) * lambda dlambda
            / int T(lambda) * lambda dlambda
    For a single wavelength:
        G = L_SSP(wave_target | t_age)

    Parameters
    ----------
    ssp_flux_at_z : array, shape (n_age, n_wave)
        SSP spectra at fixed metallicity.
    ssp_wave : array, shape (n_wave,)
        Wavelength grid (Angstrom). Must be increasing, as required by
        ``jnp.interp``.
    filter_wave : array, optional
        Filter wavelength grid.
    filter_trans : array, optional
        Filter transmission sampled on ``filter_wave``.
    wave_target : float or array, optional
        Wavelength(s) in Angstrom. Used only if no filter is provided
        (both ``filter_wave`` and ``filter_trans`` are needed for the
        photometric path). With an array of shape (n_target,), the result
        has shape (n_age, n_target).

    Returns
    -------
    array, shape (n_age,)
        Green's function G(t_age).

    Raises
    ------
    ValueError
        If neither a complete filter nor ``wave_target`` is given.
    """
    if filter_wave is not None and filter_trans is not None:
        # Photometric Green's function: photon-weighted band average.
        filter_wave = jnp.asarray(filter_wave)
        filter_trans = jnp.asarray(filter_trans)
        # The normalisation does not depend on the SSP, so compute it once.
        # The floor guards against division by zero for an all-zero filter.
        denom = jnp.maximum(jnp.trapezoid(filter_trans * filter_wave, filter_wave), 1e-30)

        def _band_integral(ssp):
            # Flux outside the SSP wavelength grid counts as zero (left/right=0).
            ssp_on_filt = jnp.interp(filter_wave, ssp_wave, ssp, left=0.0, right=0.0)
            return jnp.trapezoid(ssp_on_filt * filter_trans * filter_wave, filter_wave)

        # One vectorised pass over all ages, instead of a Python loop
        # that copies the output array once per age.
        return jax.vmap(_band_integral)(ssp_flux_at_z) / denom
    if wave_target is not None:
        # Monochromatic Green's function. jnp.interp's default left/right
        # clamps to the edge flux values outside ssp_wave.
        return jax.vmap(lambda ssp: jnp.interp(wave_target, ssp_wave, ssp))(ssp_flux_at_z)
    raise ValueError("Provide either (filter_wave, filter_trans) or wave_target")
[docs]
def compute_window_function(green_fn, mean_sfr_on_ages):
    """Weight a Green's function by the mean star-formation history.

    W(t_age) = G(t_age) * <SFR(t_age)>. Peaks in W mark the lookback
    times that dominate the observed flux for this mean SFH.

    Parameters
    ----------
    green_fn : array, shape (n_age,)
        Green's function G(t_age) on the SSP age grid.
    mean_sfr_on_ages : array, shape (n_age,)
        Mean SFR sampled on the same SSP age grid (Msun/yr).

    Returns
    -------
    array, shape (n_age,)
        Unnormalized window function: the elementwise product of the
        two inputs.
    """
    window = green_fn * mean_sfr_on_ages
    return window
[docs]
def compute_window_function_fourier(window_fn, ssp_ages_yr, omega=None):
    """Compute the power transfer |W_tilde(omega)|^2 of a window function.

    |W_tilde(omega)|^2 tells you the PSD sensitivity: how much power
    at frequency omega contributes to the variance of this observable.
    From Munoz+2026 Eq. 11:
        sigma^2_L = int |W_tilde(omega)|^2 P(omega) d_omega / (2*pi)

    The transform is evaluated directly as a non-uniform Fourier sum

        W_tilde(omega) = sum_i W(t_i) * dt_i * exp(-1j * omega * t_i),

    so it is valid on non-uniform (e.g. log-spaced) SSP age grids. A plain
    FFT assumes equally spaced samples and assigns the wrong phase to each
    non-uniformly spaced age. On a uniform grid this sum has the same modulus
    as ``rfft(W * dt)`` at the rfft frequencies.

    Parameters
    ----------
    window_fn : array, shape (n_age,)
        Window function W(t_age).
    ssp_ages_yr : array, shape (n_age,)
        SSP ages in years. Quadrature weights are dt_i = t_i - t_{i-1} with
        t_{-1} = 0, i.e. each sample stands for the age interval ending at t_i.
    omega : array, shape (n_freq,), optional
        Angular frequencies (rad/yr) at which to evaluate the transform.
        Defaults to ``2*pi*rfftfreq(n_age, d=mean(diff(ssp_ages_yr)))``
        (n_age//2 + 1 points), the grid used by earlier versions. For strongly
        non-uniform age grids, consider passing an explicit grid.

    Returns
    -------
    power_transfer : array, shape (n_freq,)
        |W_tilde(omega)|^2, the PSD-to-observable transfer function.
    omega : array, shape (n_freq,)
        Angular frequencies (rad/yr) matching ``power_transfer``.
    """
    window_fn = jnp.asarray(window_fn)
    ages = jnp.asarray(ssp_ages_yr)
    n = window_fn.shape[0]
    # Right-endpoint quadrature weights; the first bin spans [0, t_0].
    dt = jnp.diff(ages, prepend=0.0)
    if omega is None:
        d_age = jnp.mean(jnp.diff(ages))
        omega = 2.0 * jnp.pi * jnp.fft.rfftfreq(n, d=d_age)
    else:
        omega = jnp.asarray(omega)
    # Use the actual sample times in the phase, not the sample index as an
    # FFT would. The phase matrix is (n_freq, n_age), which is small for SSP grids.
    phase = jnp.exp(-1j * omega[:, None] * ages[None, :])
    w_tilde = phase @ (window_fn * dt)
    power_transfer = jnp.abs(w_tilde) ** 2
    return power_transfer, omega
[docs]
def compute_time_sensitivity_matrix(ssp_flux_at_z, ssp_wave, wavelengths_target):
    """Compute sensitivity of multiple wavelengths to different ages.

    Returns a matrix G(wavelength, age) showing which wavelengths
    are sensitive to which stellar population ages.

    Parameters
    ----------
    ssp_flux_at_z : array, shape (n_age, n_wave)
        SSP spectra at fixed metallicity.
    ssp_wave : array, shape (n_wave,)
        Wavelength grid (Angstrom).
    wavelengths_target : array, shape (n_target,)
        Target wavelengths to evaluate (Angstrom).
        E.g., [1500, 2500, 4000, 5500, 6563, 8000, 16000] for
        FUV, NUV, Balmer break, V-band, H-alpha, I-band, H-band.
        A scalar is treated as a single target.

    Returns
    -------
    sensitivity : array, shape (n_target, n_age)
        G(wavelength, age) matrix. Each row is the monochromatic Green's
        function at that wavelength.
    """
    targets = jnp.atleast_1d(jnp.asarray(wavelengths_target))
    # The monochromatic path interpolates every SSP spectrum at all targets
    # in one vectorised call, giving (n_age, n_target). This replaces a
    # per-wavelength Python loop that copied the output array each iteration.
    per_age = compute_green_function(ssp_flux_at_z, ssp_wave, wave_target=targets)
    return per_age.T