Utilities

Grid construction, GP generation, and helper functions used throughout the pipeline.

make_log_age_grid

tengri.make_log_age_grid(n_grid: int = 256, log_age_min: float = 6.0, log_age_max: float = 10.14) → Array

Create a uniform grid in log10(age/yr).

Default range: 1 Myr to ~13.8 Gyr (approximately the age of the universe).

Parameters:
  • n_grid (int, optional) – Number of grid points (should be even for FFT efficiency). Default: 256.

  • log_age_min (float, optional) – Minimum log10(age/yr). Default: 6.0 (1 Myr).

  • log_age_max (float, optional) – Maximum log10(age/yr). Default: 10.14 (~13.8 Gyr).

Returns:

Uniform grid in log10(age/yr) [dimensionless log values].

Return type:

ndarray, shape (n_grid,)

Notes

JIT-compatible: yes — uses jnp.linspace.

This grid is used as the internal representation for age in GP-based SFH models. The log-space parametrization provides better resolution at young ages and maps naturally to the logarithmic timescales of stellar evolution.

Examples

>>> from tengri import make_log_age_grid
>>> grid = make_log_age_grid(n_grid=64)
>>> grid.shape
(64,)
>>> round(float(grid[0]), 2), round(float(grid[-1]), 2)
(6.0, 10.14)
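
Because the grid is uniform in dex, it is geometric in linear time. The NumPy sketch below is a stand-in for make_log_age_grid (make_log_age_grid_sketch is a hypothetical name; np.linspace mirrors jnp.linspace). It shows what the finer resolution at young ages means: successive ages differ by the constant factor 10**d_log_age, so linear bin widths grow with age.

```python
import numpy as np

def make_log_age_grid_sketch(n_grid=256, log_age_min=6.0, log_age_max=10.14):
    """Hypothetical NumPy stand-in for make_log_age_grid (same defaults)."""
    # Evenly spaced points in log10(age/yr), both endpoints included
    return np.linspace(log_age_min, log_age_max, n_grid)

grid = make_log_age_grid_sketch(n_grid=64)

# Constant spacing in dex: (log_age_max - log_age_min) / (n_grid - 1)
d_log_age = grid[1] - grid[0]

# In linear time the grid is geometric: each age is 10**d_log_age times
# the previous one, so linear bin widths grow with age
ages_yr = 10.0 ** grid
dt_yr = np.diff(ages_yr)

print(round(d_log_age, 4))      # 0.0657
print(dt_yr[0] < dt_yr[-1])     # True: finest linear sampling at young ages
```

The same code runs with jax.numpy; under jax.jit, n_grid must be a static Python integer because it fixes the output shape.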

compute_field_gp

tengri.compute_field_gp(xi: Array, psd_sigma: float, psd_tau_yr: float, n_grid: int, d_log_age: float, field_model: str = 'drw') → tuple[Array, float]

Compute the GP realization and the lognormal correction for the field component.

Parameters:
  • xi (array, shape (n_grid,)) – Latent vector (xi ~ N(0, I)).

  • psd_sigma (float) – PSD amplitude (dex).

  • psd_tau_yr (float) – PSD timescale (yr).

  • n_grid (int) – Grid size.

  • d_log_age (float) – Grid spacing in dex.

  • field_model (str, optional) – PSD model name. Default: “drw”.

Returns:

  • gp_x (array, shape (n_grid,)) – GP realization on the log-age grid.

  • k0_half (float) – Lognormal correction: K(0)/2 = sigma_PS^2 / 4.

Notes

JIT-compatible: yes — uses gp_from_xi and PSD model functions from the field model registry.

The Gaussian process realization models burstiness via a correlated random field in log-space. The PSD model (e.g., “drw” for Damped Random Walk) controls the correlations along the log-age axis. The lognormal correction k0_half accounts for the mean bias introduced when the Gaussian field is exponentiated: for a zero-mean Gaussian s with variance K(0), E[exp(s)] = exp(K(0)/2).

Examples

>>> import jax.numpy as jnp
>>> from tengri import compute_field_gp, make_log_age_grid
>>> n = 64
>>> grid = make_log_age_grid(n)
>>> d = float(grid[1] - grid[0])
>>> xi = jnp.zeros(n)
>>> gp_x, k0_half = compute_field_gp(xi, psd_sigma=1.0, psd_tau_yr=1e8, n_grid=n, d_log_age=d)
>>> gp_x.shape
(64,)
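
The k0_half term rests on the lognormal mean identity: for a zero-mean Gaussian s with variance K(0), E[exp(s)] = exp(K(0)/2). The Monte Carlo sketch below checks that identity with NumPy and shows one plausible use, subtracting k0_half in the exponent so the exponentiated field has unit mean. It assumes natural-base exponentiation and an illustrative variance of 0.5; how tengri applies k0_half internally is not stated in this section.

```python
import numpy as np

rng = np.random.default_rng(42)

# Illustrative values only (not tengri defaults); assumes natural-base exp
k0 = 0.5                 # field variance K(0)
k0_half = k0 / 2.0       # lognormal correction K(0)/2

# Zero-mean Gaussian samples with variance K(0)
s = rng.normal(loc=0.0, scale=np.sqrt(k0), size=1_000_000)

mean_exp = np.exp(s).mean()                    # Monte Carlo estimate of E[exp(s)]
mean_corrected = np.exp(s - k0_half).mean()    # corrected field, should average ~1

print(abs(mean_exp / np.exp(k0_half) - 1) < 0.01)   # True: matches exp(K(0)/2)
print(round(mean_corrected, 2))                      # 1.0
```

Without the correction, the mean of exp(s) grows with K(0), so changing the variability amplitude would also shift the overall normalization.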

generate_gp_fourier

tengri.generate_gp_fourier(key: Array, sqrt_power: Array, n_points: int) → Array

Stochastic GP realization for mock galaxy generation.

Draws a random standardized vector and maps it to a correlated GP field.

Parameters:
  • key (jax.random.PRNGKey) – JAX random key for reproducibility.

  • sqrt_power (array_like, shape (n_freq,)) – Amplitude operator \(\sqrt{P(\omega) / d_{\mathrm{grid}}}\) at the rfft frequencies, with n_freq = n_points // 2 + 1. Pre-compute with psd_to_sqrt_power() [dimensionless].

  • n_points (int) – Number of grid points.

Returns:

GP realization on the log-age grid [dimensionless].

Return type:

ndarray, shape (n_points,)

Notes

JIT-compatible: yes — uses jax.random.normal and gp_from_xi().

This function is the primary interface for generating mock SFHs with stochastic variability. The latent draw is fully determined by key: passing the same PRNGKey reproduces the same realization, while distinct keys (e.g., from jax.random.split) give independent realizations.

See also

gp_from_xi : Deterministic GP mapping (used internally).

generate_gp_batch : Generate multiple independent realizations.

Examples

>>> import jax
>>> from tengri import generate_gp_fourier, make_log_age_grid, compute_sqrt_power_drw
>>> n = 64
>>> grid = make_log_age_grid(n)
>>> d = float(grid[1] - grid[0])
>>> sqrt_power = compute_sqrt_power_drw(n, d, psd_sigma=1.0, psd_tau_yr=1e8)
>>> key = jax.random.PRNGKey(0)
>>> sfh = generate_gp_fourier(key, sqrt_power, n)
>>> sfh.shape
(64,)
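
generate_gp_fourier splits into a random step (draw xi from the key) and the deterministic mapping of gp_from_xi. The NumPy sketch below mirrors that split under stated assumptions: an integer seed stands in for the JAX PRNGKey, the FFT mapping follows the formula given for gp_from_xi, and a Lorentzian (DRW-shaped) PSD with a hypothetical 0.3 dex correlation length replaces compute_sqrt_power_drw, whose conversion from psd_tau_yr to grid units is not shown here.

```python
import numpy as np

def gp_from_xi_sketch(xi, sqrt_power, n_points):
    # Colour white noise in Fourier space, then transform back to a real field
    return np.fft.irfft(sqrt_power * np.fft.rfft(xi), n=n_points)

def generate_gp_sketch(seed, sqrt_power, n_points):
    # Hypothetical stand-in for generate_gp_fourier: the integer seed plays
    # the role of the PRNGKey and fully determines the latent draw xi
    rng = np.random.default_rng(seed)
    xi = rng.standard_normal(n_points)
    return gp_from_xi_sketch(xi, sqrt_power, n_points)

n = 64
d = 4.14 / (n - 1)                     # grid spacing in dex for the default range
freqs = np.fft.rfftfreq(n, d=d)        # rfft frequencies in cycles per dex
tau = 0.3                              # hypothetical correlation length in dex
psd = 2 * tau / (1 + (2 * np.pi * freqs * tau) ** 2)   # Lorentzian (DRW-shaped) PSD
sqrt_power = np.sqrt(psd / d)          # amplitude operator sqrt(P / d_grid)

a = generate_gp_sketch(0, sqrt_power, n)
b = generate_gp_sketch(0, sqrt_power, n)   # same seed
c = generate_gp_sketch(1, sqrt_power, n)   # different seed

print(np.array_equal(a, b), np.array_equal(a, c))   # True False
```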

generate_gp_batch

tengri.generate_gp_batch(key: Array, sqrt_power: Array, n_points: int, n_realizations: int) → Array

Batch of independent GP realizations via vectorization.

Generates multiple independent SFH realizations in parallel using vmap.

Parameters:
  • key (jax.random.PRNGKey) – JAX random key (will be split into n_realizations independent keys).

  • sqrt_power (array_like, shape (n_freq,)) – Amplitude operator \(\sqrt{P(\omega) / d_{\mathrm{grid}}}\) at the rfft frequencies, with n_freq = n_points // 2 + 1 [dimensionless].

  • n_points (int) – Number of grid points.

  • n_realizations (int) – Number of independent realizations to generate.

Returns:

Batch of independent GP realizations [dimensionless].

Return type:

ndarray, shape (n_realizations, n_points)

Notes

JIT-compatible: yes — uses jax.vmap over generate_gp_fourier().

Each realization is an independent draw from the GP prior defined by sqrt_power. This function is useful for generating mock catalogs or for estimating uncertainties by Monte Carlo sampling.

See also

generate_gp_fourier : Single realization.

Examples

>>> import jax
>>> from tengri import generate_gp_batch, make_log_age_grid, compute_sqrt_power_drw
>>> n = 64
>>> grid = make_log_age_grid(n)
>>> d = float(grid[1] - grid[0])
>>> sqrt_power = compute_sqrt_power_drw(n, d, psd_sigma=1.0, psd_tau_yr=1e8)
>>> key = jax.random.PRNGKey(0)
>>> batch = generate_gp_batch(key, sqrt_power, n_points=n, n_realizations=10)
>>> batch.shape
(10, 64)
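
In NumPy, the batching that generate_gp_batch obtains from jax.vmap reduces to broadcasting, because the FFT routines act along one axis. The sketch below is hypothetical (generate_gp_batch_sketch is not a tengri function) and differs in one design choice: a single generator draws all latents at once, whereas the documented function splits the key into n_realizations keys, which lets each realization be regenerated on its own. The PSD is an illustrative Lorentzian with a hypothetical 0.3 dex correlation length.

```python
import numpy as np

def generate_gp_batch_sketch(seed, sqrt_power, n_points, n_realizations):
    # Hypothetical stand-in for generate_gp_batch: one generator draws an
    # (n_realizations, n_points) matrix of independent N(0, 1) latents
    rng = np.random.default_rng(seed)
    xi = rng.standard_normal((n_realizations, n_points))
    # FFTs act on the last axis and sqrt_power broadcasts across rows,
    # so every row is transformed independently (the role vmap plays in JAX)
    return np.fft.irfft(sqrt_power * np.fft.rfft(xi, axis=-1), n=n_points, axis=-1)

n = 64
d = 4.14 / (n - 1)
freqs = np.fft.rfftfreq(n, d=d)
tau = 0.3                                              # hypothetical, in dex
sqrt_power = np.sqrt(2 * tau / (1 + (2 * np.pi * freqs * tau) ** 2) / d)

batch = generate_gp_batch_sketch(0, sqrt_power, n_points=n, n_realizations=10)
print(batch.shape)                                     # (10, 64)

# Monte Carlo summary: sample standard deviation at each grid point
spread = batch.std(axis=0, ddof=1)
print(spread.shape)                                    # (64,)
```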

gp_from_xi

tengri.gp_from_xi(xi: Array, sqrt_power: Array, n_points: int) → Array

Deterministic GP realization from standardized latent vector.

Maps a standardized Gaussian random vector to a correlated GP field via Fourier-space multiplication with a spectral amplitude operator. This is the core function used during inference and mock generation.

Parameters:
  • xi (array_like, shape (n_points,)) – Standardized latent vector \(\xi \sim \mathcal{N}(0, I)\) under the prior.

  • sqrt_power (array_like, shape (n_freq,)) – Amplitude operator \(\sqrt{P(\omega) / d_{\mathrm{grid}}}\) at the rfft frequencies, with n_freq = n_points // 2 + 1. Pre-compute with psd_to_sqrt_power() [dimensionless].

  • n_points (int) – Number of grid points (should match the length of xi).

Returns:

GP realization on the log-age grid [dimensionless].

Return type:

ndarray, shape (n_points,)

Notes

JIT-compatible: yes — uses jnp.fft.rfft and jnp.fft.irfft.

Gradient-safe: yes; differentiable w.r.t. both xi and sqrt_power.

Implements the NIFTy correlated field model:

\[s = \mathrm{IFFT}(\sqrt{P} \cdot \hat{\xi}), \qquad \hat{\xi} = \mathrm{rfft}(\xi)\]

The rfft/irfft pair enforces Hermitian symmetry, so the output is real-valued, and it gives the correct variance normalization: \(E[|\mathrm{rfft}(\xi)_k|^2] = N\), so with the amplitude \(\sqrt{P/\Delta x}\), where \(\Delta x = d_{\mathrm{grid}}\) is the grid spacing, we recover \(\mathrm{Var}[s] = \int P(f)\,df\).

During MCMC inference, the sampler proposes values of \(\xi\) and this function maps them to correlated SFH realizations.

Examples

>>> import jax.numpy as jnp
>>> from tengri import gp_from_xi, make_log_age_grid, compute_sqrt_power_drw
>>> n = 64
>>> grid = make_log_age_grid(n)
>>> d = float(grid[1] - grid[0])
>>> sqrt_power = compute_sqrt_power_drw(n, d, psd_sigma=1.0, psd_tau_yr=1e8)
>>> xi = jnp.zeros(n)
>>> sfh = gp_from_xi(xi, sqrt_power, n)
>>> sfh.shape
(64,)
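
The variance normalization described in the Notes of gp_from_xi can be checked numerically. The NumPy sketch below is an illustrative check, not tengri's test suite: it draws many latent vectors, confirms E[|rfft(xi)_k|^2] ≈ N, and compares the empirical variance of s = irfft(sqrt(P/Δx) · rfft(xi)) with the discrete form of ∫P(f) df, namely (1/(N Δx)) Σ_k P(f_k) over the full spectrum. The Lorentzian PSD and its 0.3 dex correlation length are assumptions chosen only for the check.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 64, 20_000                       # grid points, Monte Carlo realizations
d = 4.14 / (n - 1)                      # grid spacing Delta x in dex
freqs = np.fft.rfftfreq(n, d=d)
tau = 0.3                               # hypothetical correlation length in dex
psd = 2 * tau / (1 + (2 * np.pi * freqs * tau) ** 2)   # two-sided Lorentzian PSD
sqrt_power = np.sqrt(psd / d)           # amplitude sqrt(P / Delta x)

xi = rng.standard_normal((m, n))
xi_hat = np.fft.rfft(xi, axis=-1)

# Check 1: E[|rfft(xi)_k|^2] = N (averaged over the interior bins)
power = np.mean(np.abs(xi_hat) ** 2, axis=0)
print(np.isclose(power[1:-1].mean(), n, rtol=0.02))      # True

# Check 2: Var[s] = (1 / (N * Delta x)) * sum of P over the full spectrum.
# The rfft keeps k = 0 and the Nyquist bin once; every other bin stands
# for a +/- frequency pair and is counted twice (valid for even n).
weights = np.full(freqs.shape, 2.0)
weights[0] = weights[-1] = 1.0
var_pred = np.sum(weights * psd) / (n * d)
s = np.fft.irfft(sqrt_power * xi_hat, n=n, axis=-1)
var_emp = s.var()
print(np.isclose(var_emp, var_pred, rtol=0.03))          # True
```

The discrete sum falls a few percent short of 1, the continuous integral of this unit-variance Lorentzian, because power above the Nyquist frequency of the grid is not represented.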

Noise Utilities

tengri.compute_effective_noise(noise_obs: Array, model_flux: Array, f_cal: float | Array) → Array

Compute effective noise with calibration floor.

σ_eff = sqrt(σ²_obs + (f_cal · model)²)

Parameters:
  • noise_obs (array, shape (n_bands,)) – Observed 1-sigma uncertainties [flux units].

  • model_flux (array, shape (n_bands,)) – Model-predicted fluxes [flux units] (absolute value used for calibration term).

  • f_cal (float or scalar array) – Fractional calibration uncertainty [dimensionless]. Typical range: 0.01–0.15.

Returns:

Effective noise standard deviation [same units as inputs].

Return type:

array, shape (n_bands,)

Notes

JIT-compatible: yes — uses only jnp primitives.

The calibration term f_cal * |model| adds a flux-dependent floor to the noise budget: σ_eff can never fall below f_cal * |model|, which keeps the likelihood from becoming overconfident when the reported measurement uncertainties are very small.

Examples

>>> import jax.numpy as jnp
>>> from tengri import compute_effective_noise
>>> noise = jnp.array([0.1, 0.2, 0.15])
>>> model = jnp.array([1.0, 2.0, 1.5])
>>> sigma_eff = compute_effective_noise(noise, model, f_cal=0.05)
>>> sigma_eff.shape
(3,)
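
The formula above is a quadrature sum of two independent error terms. The NumPy sketch below (effective_noise_sketch is a hypothetical stand-in for compute_effective_noise) reproduces the example's inputs and shows the floor behavior: with zero measurement noise, σ_eff reduces to f_cal · |model|, including for negative model fluxes.

```python
import numpy as np

def effective_noise_sketch(noise_obs, model_flux, f_cal):
    """Hypothetical NumPy version of sigma_eff = sqrt(sigma_obs^2 + (f_cal * |model|)^2)."""
    calib = f_cal * np.abs(model_flux)     # flux-proportional calibration floor
    # np.hypot(a, b) computes sqrt(a**2 + b**2) without intermediate overflow
    return np.hypot(noise_obs, calib)

noise = np.array([0.1, 0.2, 0.15])
model = np.array([1.0, 2.0, 1.5])
sigma_eff = effective_noise_sketch(noise, model, f_cal=0.05)
print(np.round(sigma_eff, 4))              # [0.1118 0.2236 0.1677]

# With zero measurement noise the calibration term alone sets sigma_eff
floor = effective_noise_sketch(np.zeros(3), -model, f_cal=0.05)
print(floor)                               # equals 0.05 * |model|: 0.05, 0.1, 0.075
```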

tengri.compute_std_inv(noise_obs: Array, model_flux: Array, f_cal: float | Array) → Array

Compute inverse effective noise (precision).

τ = 1/σ_eff. This is the second output expected by NIFTy’s VariableCovarianceGaussian likelihood.

Parameters:
  • noise_obs (array, shape (n_bands,)) – Observed 1-sigma uncertainties [flux units].

  • model_flux (array, shape (n_bands,)) – Model-predicted fluxes [flux units].

  • f_cal (float or scalar array) – Fractional calibration uncertainty [dimensionless].

Returns:

Inverse noise standard deviation τ = 1/σ_eff [1/flux_units].

Return type:

array, shape (n_bands,)

Notes

JIT-compatible: yes — delegates to compute_effective_noise() which is pure JAX.

Used in variable-covariance likelihoods where the noise is a traced parameter. See variable_noise_hamiltonian() for integration into the likelihood energy function.

Examples

>>> import jax.numpy as jnp
>>> from tengri import compute_std_inv
>>> noise = jnp.array([0.1, 0.2, 0.15])
>>> model = jnp.array([1.0, 2.0, 1.5])
>>> tau = compute_std_inv(noise, model, f_cal=0.05)
>>> tau.shape
(3,)
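
Working with τ = 1/σ_eff matters when σ_eff depends on free parameters such as f_cal, because the Gaussian normalization term then changes with those parameters. The sketch below writes the textbook Gaussian negative log-likelihood in terms of τ; it is not NIFTy's VariableCovarianceGaussian implementation, and std_inv_sketch and gaussian_nll_from_std_inv are hypothetical names. It checks the τ form against the σ-based form on made-up data.

```python
import math
import numpy as np

def std_inv_sketch(noise_obs, model_flux, f_cal):
    # Hypothetical stand-in: tau = 1 / sigma_eff with the calibration floor
    return 1.0 / np.hypot(noise_obs, f_cal * np.abs(model_flux))

def gaussian_nll_from_std_inv(data, model_flux, tau):
    """Hypothetical helper: -log L for independent Gaussians given tau = 1/sigma."""
    r = (data - model_flux) * tau            # whitened residuals
    # The -sum(log tau) term penalizes inflated noise; without it the
    # quadratic term alone could always be lowered by increasing f_cal
    return 0.5 * np.sum(r**2) - np.sum(np.log(tau)) + 0.5 * r.size * math.log(2 * math.pi)

noise = np.array([0.1, 0.2, 0.15])
model = np.array([1.0, 2.0, 1.5])
data = np.array([1.05, 1.9, 1.6])            # made-up observations for illustration
tau = std_inv_sketch(noise, model, f_cal=0.05)

# Same quantity from the sigma-based form: sum of 0.5*z^2 + log(sigma) + 0.5*log(2*pi)
sigma = 1.0 / tau
nll_sigma = sum(0.5 * ((d - m) / s) ** 2 + math.log(s) + 0.5 * math.log(2 * math.pi)
                for d, m, s in zip(data, model, sigma))
print(np.isclose(gaussian_nll_from_std_inv(data, model, tau), nll_sigma))   # True
```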

tengri.has_noise_model(spec) → bool

Check whether the noise model is active, i.e. whether any noise parameter is free or Fixed to a nonzero value.

Parameters:

spec (Parameters) – Parameter specification.

Returns:

True if the noise model is active (any noise parameter is free or Fixed to nonzero value).

Return type:

bool

Notes

Not JIT-compatible (uses Python control flow and class introspection).

This function checks if any parameter whose name starts with "noise_" is in the free parameter list, or if noise_frac_cal is explicitly fixed to a nonzero value.

Examples

>>> from tengri import Parameters, Uniform, has_noise_model
>>> spec = Parameters(dust_tau_bc=Uniform(0.1, 4.0))
>>> has_noise_model(spec)
False
>>> spec2 = Parameters(dust_tau_bc=Uniform(0.1, 4.0), noise_frac_cal=Uniform(0.01, 0.2))
>>> has_noise_model(spec2)
True
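
The rule stated in the Notes can be written as plain Python. The sketch below uses a dict of hypothetical Uniform and Fixed dataclasses in place of tengri's Parameters object, whose internals are not documented here, and follows the Notes literally: a free noise_* parameter, or noise_frac_cal fixed to a nonzero value, activates the noise model.

```python
from dataclasses import dataclass

# Hypothetical stand-ins for prior classes; these are NOT tengri's
# Parameters, Uniform, or Fixed, only enough structure to show the rule.
@dataclass(frozen=True)
class Uniform:
    low: float
    high: float

@dataclass(frozen=True)
class Fixed:
    value: float

def has_noise_model_sketch(spec: dict) -> bool:
    """Active if any noise_* parameter is free, or noise_frac_cal is Fixed to a nonzero value."""
    for name, prior in spec.items():
        if not name.startswith("noise_"):
            continue
        if not isinstance(prior, Fixed):
            return True                        # a free noise parameter
        if name == "noise_frac_cal" and prior.value != 0:
            return True                        # calibration floor fixed to nonzero
    return False

print(has_noise_model_sketch({"dust_tau_bc": Uniform(0.1, 4.0)}))        # False
print(has_noise_model_sketch({"noise_frac_cal": Uniform(0.01, 0.2)}))    # True
print(has_noise_model_sketch({"noise_frac_cal": Fixed(0.0)}))            # False
print(has_noise_model_sketch({"noise_frac_cal": Fixed(0.05)}))           # True
```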