Utilities¶
Grid construction, GP generation, and helper functions used throughout the pipeline.
make_log_age_grid¶
- tengri.make_log_age_grid(n_grid: int = 256, log_age_min: float = 6.0, log_age_max: float = 10.14) → Array¶
Create uniform grid in log10(age/yr).
Default range: 1 Myr to ~13.8 Gyr (approximately the age of the universe).
- Parameters:
  - n_grid (int) – Number of grid points. Default: 256.
  - log_age_min (float) – Lower edge of the grid in log10(age/yr). Default: 6.0 (1 Myr).
  - log_age_max (float) – Upper edge of the grid in log10(age/yr). Default: 10.14 (~13.8 Gyr).
- Returns:
  Uniform grid in log10(age/yr) [dimensionless log values].
- Return type:
  Array
Notes
JIT-compatible: yes — uses jnp.linspace.

This grid is used as the internal representation for age in GP-based SFH models. The log-space parametrization provides better resolution at young ages and maps naturally to the logarithmic timescales of stellar evolution.
Examples
>>> from tengri import make_log_age_grid
>>> grid = make_log_age_grid(n_grid=64)
>>> grid.shape
(64,)
>>> float(grid[0]), float(grid[-1])
(6.0, 10.14)
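The resolution behavior described above can be checked with a short numpy sketch. This is a hypothetical re-implementation of the grid for illustration, not the library function: uniform spacing in log10(age) translates into linear age bins that grow by orders of magnitude from the young end to the old end.

```python
import numpy as np

# Hypothetical numpy sketch of the grid (not the tengri implementation):
# uniform in log10(age/yr) from 1 Myr to ~13.8 Gyr.
log_age = np.linspace(6.0, 10.14, 64)
age_yr = 10.0 ** log_age  # linear ages in years

# Uniform log spacing => linear bin widths grow with age.
d_young = age_yr[1] - age_yr[0]    # ~1.6e5 yr near 1 Myr
d_old = age_yr[-1] - age_yr[-2]    # ~1.9e9 yr near 13.8 Gyr
```

The ~10^4-fold growth in bin width is exactly the "better resolution at young ages" property: young stellar populations, which evolve fastest, get the narrowest bins.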
compute_field_gp¶
- tengri.compute_field_gp(xi: Array, psd_sigma: float, psd_tau_yr: float, n_grid: int, d_log_age: float, field_model: str = 'drw') → tuple[Array, float]¶
Compute GP realization and lognormal correction for the field component.
- Parameters:
  - xi (array, shape (n_grid,)) – Standardized latent vector for the GP.
  - psd_sigma (float) – PSD amplitude parameter [dimensionless].
  - psd_tau_yr (float) – PSD correlation timescale [yr].
  - n_grid (int) – Number of grid points.
  - d_log_age (float) – Grid spacing in log10(age/yr).
  - field_model (str) – Name of the PSD model in the field model registry. Default: 'drw' (Damped Random Walk).
- Returns:
gp_x (array, shape (n_grid,)) – GP realization on the log-age grid.
k0_half (float) – Lognormal correction: K(0)/2 = sigma_PS^2 / 4.
Notes
JIT-compatible: yes — uses gp_from_xi and PSD model functions from the field model registry.

The Gaussian process realization models burstiness via a correlated random field in log-space. The PSD model (e.g., "drw" for Damped Random Walk) controls temporal correlations. The lognormal correction k0_half accounts for the bias introduced when exponentiation is applied to the Gaussian latents.

Examples
>>> import jax.numpy as jnp
>>> from tengri import compute_field_gp, make_log_age_grid
>>> n = 64
>>> grid = make_log_age_grid(n)
>>> d = float(grid[1] - grid[0])
>>> xi = jnp.zeros(n)
>>> gp_x, k0_half = compute_field_gp(xi, psd_sigma=1.0, psd_tau_yr=1e8, n_grid=n, d_log_age=d)
>>> gp_x.shape
(64,)
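The role of the lognormal correction can be illustrated with a small numpy Monte Carlo sketch (assumed behavior, not the library code): exponentiating a zero-mean Gaussian with marginal variance K(0) biases the mean upward by exp(K(0)/2), and subtracting K(0)/2 before exponentiating removes that bias.

```python
import numpy as np

# Hypothetical sketch of why the K(0)/2 correction is needed:
# if x ~ N(0, K(0)), then E[exp(x)] = exp(K(0)/2), not 1.
rng = np.random.default_rng(0)
k0 = 0.5                                  # marginal GP variance K(0)
x = rng.normal(0.0, np.sqrt(k0), size=200_000)

biased = np.exp(x).mean()                 # ≈ exp(K(0)/2) ≈ 1.28
corrected = np.exp(x - k0 / 2).mean()     # ≈ 1.0 after the correction
```

Subtracting k0_half before exponentiation therefore keeps the mean of the lognormal field at the intended level rather than inflating it.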
generate_gp_fourier¶
- tengri.generate_gp_fourier(key: Array, sqrt_power: Array, n_points: int) → Array¶
Stochastic GP realization for mock galaxy generation.
Draws a random standardized vector and maps it to a correlated GP field.
- key : jax.random.PRNGKey
  JAX random key for reproducibility.
- sqrt_power : array_like, shape (n_freq,)
  Amplitude operator :math:`\sqrt{P(\omega) / d_{\rm grid}}` at rfft frequencies (pre-compute with psd_to_sqrt_power()). [dimensionless]
- n_points : int
  Number of grid points.
- ndarray, shape (n_points,)
GP realization on the log-age grid [dimensionless].
JIT-compatible: yes — uses jax.random.normal and gp_from_xi().

This function is the primary interface for generating mock SFHs with stochastic variability. The random draw is always independent; for reproducibility, pass the same PRNGKey.

See also

gp_from_xi : Deterministic GP mapping (used internally).
generate_gp_batch : Generate multiple independent realizations.
Examples

>>> import jax
>>> from tengri import generate_gp_fourier, make_log_age_grid, compute_sqrt_power_drw
>>> n = 64
>>> grid = make_log_age_grid(n)
>>> d = float(grid[1] - grid[0])
>>> sqrt_power = compute_sqrt_power_drw(n, d, psd_sigma=1.0, psd_tau_yr=1e8)
>>> key = jax.random.PRNGKey(0)
>>> sfh = generate_gp_fourier(key, sqrt_power, n)
>>> sfh.shape
(64,)
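The draw-then-map pattern can be sketched in plain numpy (a hypothetical stand-in for the JAX implementation, using numpy seeding in place of PRNGKeys): the seed alone fixes the standardized vector, so the same seed reproduces the same field.

```python
import numpy as np

def gp_fourier_sketch(seed, sqrt_power, n_points):
    """Hypothetical numpy sketch of the draw-then-map logic: a fresh
    standardized vector is pushed through the spectral amplitude
    operator, so the seed alone determines the realization."""
    rng = np.random.default_rng(seed)
    xi = rng.standard_normal(n_points)
    return np.fft.irfft(sqrt_power * np.fft.rfft(xi), n=n_points)

n = 64
sqrt_power = np.ones(n // 2 + 1)  # flat (white) spectrum for illustration
a = gp_fourier_sketch(0, sqrt_power, n)
b = gp_fourier_sketch(0, sqrt_power, n)
# identical seeds give identical fields; different seeds are independent
```

This mirrors the documented contract: reproducibility comes entirely from the key, while each new key yields an independent draw.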
generate_gp_batch¶
- tengri.generate_gp_batch(key: Array, sqrt_power: Array, n_points: int, n_realizations: int) → Array¶
Batch of independent GP realizations via vectorization.
Generates multiple independent SFH realizations in parallel using vmap.
- key : jax.random.PRNGKey
  JAX random key (will be split into n_realizations independent keys).
- sqrt_power : array_like, shape (n_freq,)
  Amplitude operator :math:`\sqrt{P(\omega) / d_{\rm grid}}` at rfft frequencies. [dimensionless]
- n_points : int
  Number of grid points.
- n_realizations : int
  Number of independent realizations to generate.
- ndarray, shape (n_realizations, n_points)
Batch of independent GP realizations [dimensionless].
JIT-compatible: yes — uses jax.vmap over generate_gp_fourier().

Each realization is independent with the specified PSD structure. This function is useful for generating mock catalogs or computing uncertainties via Monte Carlo sampling.

See also

generate_gp_fourier : Single realization.
Examples

>>> import jax
>>> from tengri import generate_gp_batch, make_log_age_grid, compute_sqrt_power_drw
>>> n = 64
>>> grid = make_log_age_grid(n)
>>> d = float(grid[1] - grid[0])
>>> sqrt_power = compute_sqrt_power_drw(n, d, psd_sigma=1.0, psd_tau_yr=1e8)
>>> key = jax.random.PRNGKey(0)
>>> batch = generate_gp_batch(key, sqrt_power, n_points=n, n_realizations=10)
>>> batch.shape
(10, 64)
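The key-splitting idea behind the batch generator can be mimicked in numpy with SeedSequence.spawn (an analogy to jax.random.split + vmap, not the library code); the stacked realizations then support the Monte Carlo error estimates mentioned above.

```python
import numpy as np

# Hypothetical numpy analogue of jax.random.split + vmap:
# one parent seed spawns independent child streams.
parent = np.random.SeedSequence(0)
children = parent.spawn(10)
batch = np.stack(
    [np.random.default_rng(s).standard_normal(64) for s in children]
)  # shape (10, 64): one row per independent realization

# Monte Carlo uncertainty: per-grid-point spread across realizations.
mc_std = batch.std(axis=0)
```

Spawning from a single parent guarantees the child streams are statistically independent, which is the property the batched realizations rely on.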
gp_from_xi¶
- tengri.gp_from_xi(xi: Array, sqrt_power: Array, n_points: int) → Array¶
Deterministic GP realization from standardized latent vector.
Maps a standardized Gaussian random vector to a correlated GP field via Fourier-space multiplication with a spectral amplitude operator. This is the core function used during inference and mock generation.
- xi : array_like, shape (n_points,)
  Standardized latent vector \(\xi \sim \mathcal{N}(0, I)\) under the prior.
- sqrt_power : array_like, shape (n_freq,)
  Amplitude operator :math:`\sqrt{P(\omega) / d_{\rm grid}}` at rfft frequencies (pre-compute with psd_to_sqrt_power()). [dimensionless]
- n_points : int
  Number of grid points (should match the length of xi).
- ndarray, shape (n_points,)
GP realization on the log-age grid [dimensionless].
JIT-compatible: yes — uses jnp.fft.rfft and jnp.fft.irfft.

Gradient-safe: yes — differentiable w.r.t. sqrt_power.
Implements the NIFTy correlated field model:
\[s = \mathrm{IFFT}(\sqrt{P} \cdot \hat{\xi})\]

The rfft (real FFT) preserves Hermitian symmetry for real-valued output and ensures correct variance normalization: \(E[|\mathrm{rfft}(\xi)_k|^2] = N\), so with \(\sqrt{P/\Delta x}\) we recover \(\mathrm{Var}[s] = \int P(f)\,df\).
This is the primary function called during MCMC inference and mock galaxy generation. The sampler proposes values of \(\xi\) and this function maps them to correlated SFH realizations.
Examples

>>> import jax.numpy as jnp
>>> from tengri import gp_from_xi, make_log_age_grid, compute_sqrt_power_drw
>>> n = 64
>>> grid = make_log_age_grid(n)
>>> d = float(grid[1] - grid[0])
>>> sqrt_power = compute_sqrt_power_drw(n, d, psd_sigma=1.0, psd_tau_yr=1e8)
>>> xi = jnp.zeros(n)
>>> sfh = gp_from_xi(xi, sqrt_power, n)
>>> sfh.shape
(64,)
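The variance-normalization claim can be checked numerically with a numpy sketch (a re-derivation under the stated convention, not the library function): build fields with the sqrt(P/Δx) operator and compare their empirical variance to the PSD summed over the full two-sided frequency grid.

```python
import numpy as np

# Hypothetical numpy check of Var[s] = ∫ P(f) df under the
# sqrt(P / dx) convention (not the tengri implementation).
rng = np.random.default_rng(0)
n, dx = 64, 1.0
freqs = np.fft.rfftfreq(n, d=dx)
P = 1.0 / (1.0 + (2 * np.pi * freqs * 2.0) ** 2)   # Lorentzian (DRW-like) PSD
sqrt_power = np.sqrt(P / dx)

xi = rng.standard_normal((5000, n))                # many standardized latents
s = np.fft.irfft(sqrt_power * np.fft.rfft(xi, axis=1), n=n, axis=1)

df = 1.0 / (n * dx)
# Full two-sided spectrum: interior rfft bins appear twice.
var_expected = (P[0] + P[-1] + 2.0 * P[1:-1].sum()) * df
var_empirical = s.var()                            # matches to Monte Carlo error
```

The agreement holds because E[|rfft(ξ)_k|²] = N at every bin, so the 1/N² of the inverse transform and the N of the latent spectrum cancel to leave exactly the PSD integral.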
Noise Utilities¶
- tengri.compute_effective_noise(noise_obs: Array, model_flux: Array, f_cal: float | Array) → Array¶
Compute effective noise with calibration floor.
σ_eff = sqrt(σ²_obs + (f_cal · model)²)
- Parameters:
noise_obs (array, shape (n_bands,)) – Observed 1-sigma uncertainties [flux units].
model_flux (array, shape (n_bands,)) – Model-predicted fluxes [flux units] (absolute value used for calibration term).
f_cal (float or scalar array) – Fractional calibration uncertainty [dimensionless]. Typical range: 0.01–0.15.
- Returns:
Effective noise standard deviation [same units as inputs].
- Return type:
array, shape (n_bands,)
Notes
JIT-compatible: yes — uses only jnp primitives.
The calibration term f_cal * |model| adds a flux-dependent floor to the noise budget, preventing zero-noise solutions when measurement uncertainties are very small.

Examples
>>> import jax.numpy as jnp
>>> from tengri import compute_effective_noise
>>> noise = jnp.array([0.1, 0.2, 0.15])
>>> model = jnp.array([1.0, 2.0, 1.5])
>>> sigma_eff = compute_effective_noise(noise, model, f_cal=0.05)
>>> sigma_eff.shape
(3,)
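The floor behavior can be demonstrated with plain numpy (the formula is reproduced from the docstring above; this is a sketch, not the library call): a band with a near-zero reported error still receives sigma_eff ≈ f_cal · |model|.

```python
import numpy as np

# Hypothetical numpy sketch of the effective-noise formula:
# sigma_eff = sqrt(sigma_obs**2 + (f_cal * |model|)**2)
noise_obs = np.array([1e-6, 0.2, 0.15])   # first band: near-zero reported error
model = np.array([1.0, 2.0, 1.5])
f_cal = 0.05

sigma_eff = np.sqrt(noise_obs**2 + (f_cal * np.abs(model))**2)
# the floor keeps the first band's noise at ~f_cal * |model| = 0.05,
# preventing a zero-noise (infinite-weight) band in the likelihood
```

Without the quadrature term, the first band would dominate the fit with essentially infinite weight; with it, no band can be weighted more tightly than the calibration accuracy allows.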
- tengri.compute_std_inv(noise_obs: Array, model_flux: Array, f_cal: float | Array) → Array¶
Compute inverse effective noise (precision).
τ = 1/σ_eff. This is the second output expected by NIFTy's VariableCovarianceGaussian likelihood.

- Parameters:
  - noise_obs (array, shape (n_bands,)) – Observed 1-sigma uncertainties [flux units].
  - model_flux (array, shape (n_bands,)) – Model-predicted fluxes [flux units].
  - f_cal (float or scalar array) – Fractional calibration uncertainty [dimensionless].
- Returns:
Inverse noise standard deviation τ = 1/σ_eff [1/flux_units].
- Return type:
array, shape (n_bands,)
Notes
JIT-compatible: yes — delegates to compute_effective_noise(), which is pure JAX.

Used in variable-covariance likelihoods where the noise is a traced parameter. See variable_noise_hamiltonian() for integration into the likelihood energy function.

Examples
>>> import jax.numpy as jnp
>>> from tengri import compute_std_inv
>>> noise = jnp.array([0.1, 0.2, 0.15])
>>> model = jnp.array([1.0, 2.0, 1.5])
>>> tau = compute_std_inv(noise, model, f_cal=0.05)
>>> tau.shape
(3,)
- tengri.has_noise_model(spec) → bool¶
Check if any noise parameter is free (not Fixed at 0).
- Parameters:
spec (Parameters) – Parameter specification.
- Returns:
True if the noise model is active (any noise parameter is free or Fixed to nonzero value).
- Return type:
  bool
Notes
Not JIT-compatible (uses Python control flow and class introspection).
This function checks if any parameter whose name starts with "noise_" is in the free parameter list, or if noise_frac_cal is explicitly fixed to a nonzero value.

Examples
>>> from tengri import Parameters, Uniform, has_noise_model
>>> spec = Parameters(dust_tau_bc=Uniform(0.1, 4.0))
>>> has_noise_model(spec)
False
>>> spec2 = Parameters(dust_tau_bc=Uniform(0.1, 4.0), noise_frac_cal=Uniform(0.01, 0.2))
>>> has_noise_model(spec2)
True