Distributions

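Prior distributions for model parameters. Each distribution provides sample and log_prob methods, plus standardize/unstandardize transforms between physical parameter space and a standard-normal latent space, all compatible with JAX tracing.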

Uniform

class tengri.Uniform(lo: float, hi: float)

Bases: Distribution

Uniform prior on [lo, hi].

A flat probability density on the interval [lo, hi]. Commonly used for bounded astrophysical quantities with little prior knowledge. Reparameterizes via sigmoid to ensure differentiability and automatic bound satisfaction.

Parameters:
  • lo (float) – Lower bound (inclusive).

  • hi (float) – Upper bound (inclusive). Must satisfy hi > lo.

lo

Lower bound of the distribution.

Type:

float

hi

Upper bound of the distribution.

Type:

float

bounds

(lo, hi) convenience tuple.

Type:

tuple[float, float]

Raises:

ValueError – If lo >= hi.

Notes

JIT-compatible: yes — all operations use jnp primitives.

Standardization: Maps ξ ~ N(0,1) to θ via the sigmoid function:

\[\theta = lo + (hi - lo) \cdot \sigma(\xi)\]

where σ(ξ) = 1 / (1 + exp(-ξ)). At ξ = 0 (prior center), θ = (lo + hi) / 2. This ensures automatic bound satisfaction and smooth gradients.
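For reference, the sigmoid map above and its inverse can be written in a few lines of plain Python. This is a sketch of the formula only. It assumes that standardize applies the logit of the normalized value. The helper names (sigmoid, uniform_unstandardize, uniform_standardize) are illustrative and not part of tengri, which uses jnp primitives rather than the math module.

```python
import math


def sigmoid(xi: float) -> float:
    """Logistic function 1 / (1 + exp(-xi)), split into two branches to avoid overflow."""
    if xi >= 0.0:
        return 1.0 / (1.0 + math.exp(-xi))
    z = math.exp(xi)
    return z / (1.0 + z)


def uniform_unstandardize(xi: float, lo: float, hi: float) -> float:
    """Map a latent xi to theta = lo + (hi - lo) * sigmoid(xi)."""
    return lo + (hi - lo) * sigmoid(xi)


def uniform_standardize(theta: float, lo: float, hi: float) -> float:
    """Inverse map: xi = logit((theta - lo) / (hi - lo)); undefined exactly at lo and hi."""
    u = (theta - lo) / (hi - lo)
    return math.log(u) - math.log1p(-u)


lo, hi = 2.0, 6.0
print(uniform_unstandardize(0.0, lo, hi))  # 4.0, the midpoint
print(sigmoid(-3.0), sigmoid(3.0))  # about 0.047 and 0.953
print(uniform_standardize(uniform_unstandardize(1.25, lo, hi), lo, hi))  # about 1.25
```

The two-branch sigmoid keeps math.exp from overflowing for large negative ξ. The logit diverges at the bounds, because lo and hi correspond to ξ = -∞ and ξ = +∞ respectively.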

Examples

>>> import jax.random
>>> from tengri import Uniform
>>> prior = Uniform(0, 1)
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample: {sample:.4f}")  # Will be in [0, 1)
>>> log_prob = prior.log_prob(0.5)
>>> print(f"log p(0.5): {log_prob:.4f}")  # ≈ 0.0 (log(1) = 0)

property bounds: tuple[float, float]

Lower and upper bounds [lo, hi].

Returns:

Bounds as (lo, hi) tuple.

Return type:

tuple[float, float]

property hi: float

Upper bound of the uniform distribution.

Returns:

Upper bound value.

Return type:

float

property lo: float

Lower bound of the uniform distribution.

Returns:

Lower bound value.

Return type:

float

log_prob(x: Array) → Array

Return log probability: -log(hi-lo) inside bounds, -inf outside.

Parameters:

x (float or array_like) – Parameter value in physical space.

Returns:

Log probability density at x.

Return type:

float
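The log_prob rule above can be sketched in plain Python as follows. The helper is illustrative and is not the tengri source. Under jax.jit, a Python `if` on a traced value raises an error, so a traceable version would choose between the two branches with jnp.where. This sketch works only on plain scalars and uses the math module.

```python
import math


def uniform_log_prob(x: float, lo: float, hi: float) -> float:
    """Log density of Uniform(lo, hi): -log(hi - lo) inside [lo, hi], -inf outside."""
    if lo <= x <= hi:
        return -math.log(hi - lo)
    return -math.inf


print(uniform_log_prob(3.0, 0.0, 4.0))  # -log(4), about -1.386
print(uniform_log_prob(5.0, 0.0, 4.0))  # -inf (outside the support)
```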

sample(key: Array) → Array

Draw one sample uniformly from [lo, hi].

Parameters:

key (jax.Array) – JAX PRNG key for random sampling.

Returns:

A single sample uniformly distributed in [lo, hi].

Return type:

ndarray

standardize(theta: Array) → Array

Uniform(lo, hi) → ξ via the logit, the inverse of unstandardize: ξ = logit((θ - lo) / (hi - lo)).

Parameters:

theta (float or array_like) – Physical-space parameter value in [lo, hi].

Returns:

Standardized latent-space value.

Return type:

float or ndarray

unstandardize(xi: Array) → Array

ξ ~ N(0,1) → Uniform(lo, hi) via sigmoid.

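At ξ = 0 (prior center), sigmoid(0) = 0.5, so θ is the midpoint of [lo, hi]. At ξ = ±3 (~99.7% of the N(0,1) mass), sigmoid(±3) ≈ 0.047 and 0.953, so θ spans roughly the central 90% of [lo, hi]. The sigmoid keeps θ inside the bounds without clipping. The image of N(0,1) under this map is a logit-normal distribution, not an exact Uniform(lo, hi).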

Parameters:

xi (float or array_like) – Standardized latent value from standard normal distribution.

Returns:

Physical-space parameter in [lo, hi].

Return type:

float or ndarray

Gaussian

class tengri.Gaussian(mu: float, sigma: float, lo: float = -inf, hi: float = inf)

Bases: Distribution

Gaussian (normal) prior, optionally clipped to [lo, hi].

A bell-curve probability density centered at μ with standard deviation σ. Useful when prior information suggests a most-probable value with uncertainty. Optional bounds allow truncation to physical ranges.

Parameters:
  • mu (float) – Mean of the Gaussian distribution.

  • sigma (float) – Standard deviation. Must be positive.

  • lo (float, optional) – Lower truncation bound. Default: -∞ (no lower truncation).

  • hi (float, optional) – Upper truncation bound. Default: +∞ (no upper truncation).

mu

Mean of the distribution.

Type:

float

sigma

Standard deviation of the distribution.

Type:

float

lo

Lower truncation bound (-inf if unbounded).

Type:

float

hi

Upper truncation bound (+inf if unbounded).

Type:

float

bounds

(lo, hi) convenience tuple.

Type:

tuple[float, float]

Raises:

ValueError – If sigma <= 0 or lo >= hi.

Notes

JIT-compatible: yes — all operations use jnp primitives.

Standardization: Maps ξ ~ N(0,1) to θ via:

\[\theta = \text{clip}(\mu + \sigma \cdot \xi, lo, hi)\]

Normalization: When either bound is finite, the density is normalized over [lo, hi], not over the full real line. Use this for physically bounded quantities.
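The following plain-Python sketch shows what "normalized over [lo, hi]" means. It assumes the standard truncated-normal construction, in which the untruncated log density is shifted by minus the log of the Gaussian probability mass inside [lo, hi]. The helper names are illustrative, and the code uses the math module rather than jnp.

```python
import math


def std_normal_cdf(z: float) -> float:
    """Standard normal CDF via erf; math.erf(+/-inf) returns +/-1."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def truncated_gaussian_log_prob(x, mu, sigma, lo=-math.inf, hi=math.inf):
    """log N(x | mu, sigma^2) renormalized over [lo, hi]; -inf outside."""
    if not (lo <= x <= hi):
        return -math.inf
    z = (x - mu) / sigma
    log_pdf = -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)
    # Gaussian probability mass inside [lo, hi]; equals 1 when both bounds are infinite.
    mass = std_normal_cdf((hi - mu) / sigma) - std_normal_cdf((lo - mu) / sigma)
    return log_pdf - math.log(mass)


# Metallicity-style prior from the example below: log p(mu) = -log(sigma * sqrt(2*pi)).
print(round(truncated_gaussian_log_prob(-0.3, -0.3, 0.2), 4))  # 0.6905
# Truncating N(0, 1) at lo = 0 keeps half the mass, which raises log p by log(2).
print(truncated_gaussian_log_prob(0.0, 0.0, 1.0, lo=0.0)
      - truncated_gaussian_log_prob(0.0, 0.0, 1.0))  # about 0.6931
```

A positive log density such as 0.6905 is expected here: with σ = 0.2, the peak density exceeds 1.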

Examples

>>> import jax.random
>>> from tengri import Gaussian
>>> prior = Gaussian(mu=-0.3, sigma=0.2)  # metallicity
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample: {sample:.3f}")  # Typically near -0.3
>>> log_prob = prior.log_prob(-0.3)
>>> print(f"log p(μ): {log_prob:.4f}")  # Maximum at mean

property bounds: tuple[float, float]

Lower and upper truncation bounds [lo, hi].

Returns:

Bounds as (lo, hi) tuple.

Return type:

tuple[float, float]

property hi: float

Upper truncation bound.

Returns:

Upper truncation bound (+inf if unbounded).

Return type:

float

property lo: float

Lower truncation bound.

Returns:

Lower truncation bound (-inf if unbounded).

Return type:

float

log_prob(x: Array) → Array

Evaluate log probability density, returning -inf outside bounds.

Parameters:

x (float or array_like) – Parameter value in physical space.

Returns:

Log probability density at x.

Return type:

float

property mu: float

Mean of the Gaussian distribution.

Returns:

Mean value.

Return type:

float

sample(key: Array) → Array

Draw a random sample from the Gaussian distribution.

Parameters:

key (jax.Array) – JAX PRNG key for random sampling.

Returns:

A single sample from N(mu, sigma²), clipped to [lo, hi].

Return type:

ndarray

property sigma: float

Standard deviation of the Gaussian distribution.

Returns:

Standard deviation value.

Return type:

float

standardize(theta: Array) → Array

N(μ, σ²) → ξ, inverting the affine map of unstandardize: ξ = (θ - μ) / σ.

Parameters:

theta (float or array_like) – Physical-space parameter value.

Returns:

Standardized latent-space value.

Return type:

float or ndarray

unstandardize(xi: Array) → Array

ξ ~ N(0,1) → N(μ,σ²) clipped to [lo, hi].

Parameters:

xi (float or array_like) – Standardized latent value from standard normal distribution.

Returns:

Physical-space parameter in [lo, hi].

Return type:

float or ndarray
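The plain-Python sketch below implements the clip-based map from the Notes and its affine inverse. The names are illustrative and this is not tengri's code. The sketch also shows a side effect of clipping. When a bound is finite, every latent draw beyond it lands exactly on that bound. As a result, the pushforward of N(0,1) carries a point mass at the bound instead of following the renormalized density that log_prob describes.

```python
import math


def gaussian_unstandardize(xi, mu, sigma, lo=-math.inf, hi=math.inf):
    """theta = clip(mu + sigma * xi, lo, hi), as in the Notes above."""
    return min(max(mu + sigma * xi, lo), hi)


def gaussian_standardize(theta, mu, sigma):
    """Inverse of the affine part: xi = (theta - mu) / sigma."""
    return (theta - mu) / sigma


def latent_mass_clipped_to_lo(mu, sigma, lo):
    """P(mu + sigma * xi < lo) for xi ~ N(0, 1): the mass that clipping maps onto lo."""
    return 0.5 * (1.0 + math.erf((lo - mu) / (sigma * math.sqrt(2.0))))


print(gaussian_unstandardize(-3.0, mu=0.0, sigma=1.0, lo=-1.0))  # -1.0 (clipped)
print(latent_mass_clipped_to_lo(mu=0.0, sigma=1.0, lo=0.0))  # 0.5
```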

LogUniform

class tengri.LogUniform(lo: float, hi: float)

Bases: Distribution

Uniform in log10 space on [lo, hi].

A prior that places equal probability mass in each logarithmic interval (for example, each decade), giving a density proportional to 1/x in linear space. Useful for quantities whose uncertainty spans orders of magnitude, such as star formation rates, timescales, and luminosities.

Parameters:
  • lo (float) – Lower bound. Must be strictly positive.

  • hi (float) – Upper bound. Must be greater than lo.

lo

Lower bound of the distribution.

Type:

float

hi

Upper bound of the distribution.

Type:

float

bounds

(lo, hi) convenience tuple.

Type:

tuple[float, float]

Raises:

ValueError – If lo <= 0 or lo >= hi.

Notes

JIT-compatible: yes — all operations use jnp primitives.

The probability density in linear space is:

\[p(x) = \frac{1}{x \cdot \ln(10) \cdot \log_{10}(hi / lo)}\]

where x ∈ [lo, hi]. Samples drawn from this distribution are equally spaced in log10 space: log10(x) ~ U(log10(lo), log10(hi)).
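The density and its equal-mass-per-decade property can be checked with a short plain-Python sketch. The helper is illustrative. It uses the natural-log form -log(x · ln(hi/lo)), which matches the formula above because ln(10) · log10(hi/lo) = ln(hi/lo).

```python
import math


def log_uniform_log_prob(x, lo, hi):
    """log p(x) = -log(x * ln(hi / lo)) on [lo, hi], -inf outside."""
    if not (lo <= x <= hi):
        return -math.inf
    return -math.log(x * math.log(hi / lo))


lo, hi = 1e-2, 1e2
print(round(log_uniform_log_prob(1.0, lo, hi), 4))  # -2.2203

# Integrating the 1/x density over one decade, [1, 10], gives ln(10) / ln(hi / lo).
mass_one_decade = math.log(10.0) / math.log(hi / lo)
print(round(mass_one_decade, 6))  # 0.25: each of the 4 decades gets equal mass
```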

Standardization: Maps ξ ~ N(0,1) to θ via sigmoid in log space:

\[\theta = 10^{\log_{10}(lo) + (\log_{10}(hi) - \log_{10}(lo)) \cdot \sigma(\xi)}\]
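A plain-Python sketch of this log-space sigmoid map and its logit inverse follows. It assumes that standardize inverts the map exactly, and the helper names are illustrative. At ξ = 0 the map returns the geometric mean sqrt(lo · hi), which is the midpoint in log space.

```python
import math


def sigmoid(xi: float) -> float:
    """Numerically stable logistic function."""
    if xi >= 0.0:
        return 1.0 / (1.0 + math.exp(-xi))
    z = math.exp(xi)
    return z / (1.0 + z)


def log_uniform_unstandardize(xi, lo, hi):
    """theta = 10 ** (log10(lo) + (log10(hi) - log10(lo)) * sigmoid(xi))."""
    a, b = math.log10(lo), math.log10(hi)
    return 10.0 ** (a + (b - a) * sigmoid(xi))


def log_uniform_standardize(theta, lo, hi):
    """Inverse map: xi = logit((log10(theta) - log10(lo)) / (log10(hi) - log10(lo)))."""
    a, b = math.log10(lo), math.log10(hi)
    u = (math.log10(theta) - a) / (b - a)
    return math.log(u) - math.log1p(-u)


print(log_uniform_unstandardize(0.0, 2.0, 50.0))  # about 10.0 = sqrt(2 * 50)
```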

Examples

>>> import jax.random
>>> from tengri import LogUniform
>>> prior = LogUniform(1e-2, 1e2)  # spans 4 orders of magnitude
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample: {sample:.3e}")
>>> log_prob = prior.log_prob(1.0)  # Center of log space
>>> print(f"log p(1.0): {log_prob:.4f}")

property bounds: tuple[float, float]

Lower and upper bounds [lo, hi].

Returns:

Bounds as (lo, hi) tuple.

Return type:

tuple[float, float]

property hi: float

Upper bound of the log-uniform distribution.

Returns:

Upper bound value.

Return type:

float

property lo: float

Lower bound of the log-uniform distribution.

Returns:

Lower bound value.

Return type:

float

log_prob(x: Array) → Array

Return log probability: -log(x * log(hi/lo)) inside bounds, -inf outside.

Parameters:

x (float or array_like) – Parameter value in physical space.

Returns:

Log probability density at x.

Return type:

float

sample(key: Array) → Array

Draw one sample log-uniformly from [lo, hi].

Parameters:

key (jax.Array) – JAX PRNG key for random sampling.

Returns:

A single sample log-uniformly distributed in [lo, hi].

Return type:

ndarray

standardize(theta: Array) → Array

LogUniform(lo, hi) → ξ via logit in log space.

Parameters:

theta (float or array_like) – Physical-space parameter value in [lo, hi].

Returns:

Standardized latent-space value.

Return type:

float or ndarray

unstandardize(xi: Array) → Array

ξ ~ N(0,1) → LogUniform(lo, hi) via sigmoid in log space.

Parameters:

xi (float or array_like) – Standardized latent value from standard normal distribution.

Returns:

Physical-space parameter in [lo, hi].

Return type:

float or ndarray

LogNormal

class tengri.LogNormal(mu: float = 0.0, sigma: float = 1.0, lo: float = 0.0, hi: float = inf)

Bases: Distribution

Log-normal prior: log(θ) ~ N(μ, σ²).

A prior suitable for strictly positive quantities with multiplicative uncertainty, such as timescales, amplitudes, and scale factors. The natural logarithm of the parameter is normally distributed.

Parameters:
  • mu (float, optional) – Mean of log(θ). Default: 0.0.

  • sigma (float, optional) – Standard deviation of log(θ). Must be positive. Default: 1.0.

  • lo (float, optional) – Lower truncation bound. Default: 0.0 (ensures θ > 0).

  • hi (float, optional) – Upper truncation bound. Default: +∞ (no upper truncation).

mu

Mean of log(theta).

Type:

float

sigma

Standard deviation of log(theta).

Type:

float

bounds

(lo, hi) convenience tuple.

Type:

tuple[float, float]

Raises:

ValueError – If sigma <= 0.

Notes

JIT-compatible: yes — all operations use jnp primitives.

The probability density in linear space is:

\[p(\theta) = \frac{1}{\theta \sigma \sqrt{2\pi}} \exp\left( -\frac{(\ln \theta - \mu)^2}{2\sigma^2} \right)\]

for θ ∈ [lo, hi]. When truncated, the density is renormalized over the interval.
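Below is a plain-Python sketch of the optionally truncated log-normal log density. It assumes that renormalization divides by the Gaussian mass of ln θ between ln(lo) and ln(hi). With the default lo = 0, the lower limit is ln 0 = -∞. The helper names are illustrative and do not come from tengri.

```python
import math


def std_normal_cdf(z: float) -> float:
    """Standard normal CDF via erf; math.erf(+/-inf) returns +/-1."""
    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))


def log_normal_log_prob(theta, mu=0.0, sigma=1.0, lo=0.0, hi=math.inf):
    """Log density of ln(theta) ~ N(mu, sigma^2), renormalized over [lo, hi]."""
    if theta <= 0.0 or not (lo <= theta <= hi):
        return -math.inf
    z = (math.log(theta) - mu) / sigma
    log_pdf = -0.5 * z * z - math.log(theta * sigma) - 0.5 * math.log(2.0 * math.pi)
    # Truncation limits in z-space; ln(0) = -inf for the default lower bound lo = 0.
    z_lo = (math.log(lo) - mu) / sigma if lo > 0.0 else -math.inf
    z_hi = (math.log(hi) - mu) / sigma
    mass = std_normal_cdf(z_hi) - std_normal_cdf(z_lo)
    return log_pdf - math.log(mass)


print(log_normal_log_prob(1.0))  # -0.5 * ln(2*pi), about -0.9189
print(log_normal_log_prob(1.0, hi=1.0) - log_normal_log_prob(1.0))  # ln 2: half the mass remains
```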

Standardization: Maps ξ ~ N(0,1) to θ via:

\[\theta = \text{clip}(\exp(\mu + \sigma \cdot \xi), lo, hi)\]
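Because μ and σ are in natural-log units, a prior specified in log10 terms (a log10 center and a width in dex) needs a factor of ln(10). The sketch below shows that conversion together with the clip-based map and its inverse. dex_to_natural and the other names are hypothetical helpers, not tengri functions.

```python
import math

LN10 = math.log(10.0)


def dex_to_natural(log10_center: float, width_dex: float) -> tuple[float, float]:
    """Convert a log10 center and a width in dex to natural-log (mu, sigma)."""
    return log10_center * LN10, width_dex * LN10


def log_normal_unstandardize(xi, mu, sigma, lo=0.0, hi=math.inf):
    """theta = clip(exp(mu + sigma * xi), lo, hi), as in the formula above."""
    return min(max(math.exp(mu + sigma * xi), lo), hi)


def log_normal_standardize(theta, mu, sigma):
    """Inverse of the unclipped map: xi = (ln(theta) - mu) / sigma."""
    return (math.log(theta) - mu) / sigma


mu, sigma = dex_to_natural(8.0, 0.5)  # median 1e8, width 0.5 dex
print(f"{log_normal_unstandardize(0.0, mu, sigma):.3e}")  # 1.000e+08
print(f"{log_normal_unstandardize(1.0, mu, sigma):.3e}")  # 3.162e+08, i.e. +0.5 dex
```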

Examples

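>>> import math
>>> import jax.random
>>> from tengri import LogNormal
>>> # PSD timescale: median 1e8 yr, width 0.5 dex. mu and sigma are in
>>> # natural-log units, so convert from log10 by multiplying by ln(10).
>>> prior = LogNormal(mu=8 * math.log(10), sigma=0.5 * math.log(10))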
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample (yr): {sample:.3e}")
>>> log_prob = prior.log_prob(1e8)
>>> print(f"log p(1e8): {log_prob:.4f}")

property bounds: tuple[float, float]

Lower and upper truncation bounds [lo, hi].

Returns:

Bounds as (lo, hi) tuple.

Return type:

tuple[float, float]

log_prob(x: Array) → Array

Evaluate log probability density, returning -inf outside bounds.

Parameters:

x (float or array_like) – Parameter value in physical space.

Returns:

Log probability density at x.

Return type:

float

property mu: float

Mean of log(theta).

Returns:

Mean of the logarithm of the parameter.

Return type:

float

sample(key: Array) → Array

Draw a random sample from the log-normal distribution.

Parameters:

key (jax.Array) – JAX PRNG key for random sampling.

Returns:

A single sample from LogNormal(mu, sigma²), clipped to [lo, hi].

Return type:

ndarray

property sigma: float

Standard deviation of log(theta).

Returns:

Standard deviation of the logarithm of the parameter.

Return type:

float

standardize(theta: Array) → Array

LogNormal → ξ, inverting unstandardize: ξ = (ln θ - μ) / σ.

Parameters:

theta (float or array_like) – Physical-space parameter value.

Returns:

Standardized latent-space value.

Return type:

float or ndarray

unstandardize(xi: Array) → Array

ξ ~ N(0,1) → exp(μ + σ·ξ), clipped to [lo, hi].

Parameters:

xi (float or array_like) – Standardized latent value from standard normal distribution.

Returns:

Physical-space parameter in [lo, hi].

Return type:

float or ndarray

StudentT

class tengri.StudentT(mu: float = 0.0, sigma: float = 1.0, df: float = 3.0, lo: float = -inf, hi: float = inf)

Bases: Distribution

Student’s t prior with heavier tails than Gaussian.

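A robust prior with heavier tails than a Gaussian, useful for parameters that may occasionally take outlying values: large departures from the center are penalized far less than under a Gaussian. Commonly used in BAGPIPES-style SED fitting, where it keeps most prior mass near the center while still allowing occasional large departures.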

Parameters:
  • mu (float, optional) – Location (center) of the distribution. Default: 0.0.

  • sigma (float, optional) – Scale parameter. Must be positive. Default: 1.0.

  • df (float, optional) – Degrees of freedom. Controls tail weight:
      - df → ∞ gives Gaussian (heaviest concentration at center)
      - df = 3 gives a moderately heavy-tailed prior
      - df = 1 gives Cauchy (extremely heavy tails)
    Default: 3.0.

  • lo (float, optional) – Lower truncation bound. Default: -∞ (no lower truncation).

  • hi (float, optional) – Upper truncation bound. Default: +∞ (no upper truncation).

bounds

(lo, hi) truncation bounds.

Type:

tuple[float, float]

Notes

JIT-compatible: yes — all operations use jnp primitives.

The probability density is the location-scale Student's t density (location mu, scale sigma) with the standard normalization. For finite df, it has heavier tails than a Gaussian.
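For reference, the location-scale Student's t log density can be written with math.lgamma. The sketch below is the textbook untruncated formula, not the tengri source. It can be checked against the Cauchy case df = 1, where the density at the center is 1/(π σ).

```python
import math


def student_t_log_prob(x, mu=0.0, sigma=1.0, df=3.0):
    """Untruncated location-scale Student's t log density."""
    z = (x - mu) / sigma
    return (
        math.lgamma((df + 1.0) / 2.0)
        - math.lgamma(df / 2.0)
        - 0.5 * math.log(df * math.pi)
        - math.log(sigma)
        - (df + 1.0) / 2.0 * math.log1p(z * z / df)
    )


def gaussian_log_prob(x, mu=0.0, sigma=1.0):
    """Untruncated Gaussian log density, for comparison."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)


# Five scale units from the center, df = 3 assigns far more density than a Gaussian.
print(student_t_log_prob(5.0), gaussian_log_prob(5.0))  # about -5.47 vs -13.42
```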

Standardization: Uses a Gaussian approximation with variance scaling:

\[\theta = \text{clip}(\mu + \sigma \cdot \sqrt{df/(df-2)} \cdot \xi, lo, hi)\]

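This variance matching requires df > 2, because the Student's t variance is infinite for 1 < df ≤ 2 and undefined for df ≤ 1. For df ≤ 2, a fallback scale of 3 is used.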
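The variance-matched scale can be sketched as follows. The df > 2 branch follows the formula above. The df ≤ 2 branch is an assumption that reads "a fallback scale of 3" as replacing the factor sqrt(df/(df - 2)) with 3. All names are illustrative.

```python
import math


def student_t_latent_scale(df: float) -> float:
    """Factor multiplying sigma in the Gaussian approximation.

    For df > 2, sqrt(df / (df - 2)) matches the t variance sigma**2 * df / (df - 2).
    ASSUMPTION: for df <= 2 the factor itself is replaced by 3.0.
    """
    if df > 2.0:
        return math.sqrt(df / (df - 2.0))
    return 3.0


def student_t_unstandardize(xi, mu=0.0, sigma=1.0, df=3.0, lo=-math.inf, hi=math.inf):
    """theta = clip(mu + sigma * scale(df) * xi, lo, hi)."""
    theta = mu + sigma * student_t_latent_scale(df) * xi
    return min(max(theta, lo), hi)


print(student_t_latent_scale(3.0))   # sqrt(3), about 1.732 for the default df
print(student_t_latent_scale(30.0))  # about 1.035: close to Gaussian for large df
```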

Examples

>>> import jax.random
>>> from tengri import StudentT
>>> prior = StudentT(mu=0, sigma=1, df=3)  # Robust prior
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample: {sample:.4f}")

property bounds: tuple[float, float]

Lower and upper truncation bounds [lo, hi].

Returns:

Bounds as (lo, hi) tuple.

Return type:

tuple[float, float]

log_prob(x: Array) → Array

Evaluate log probability density, returning -inf outside bounds.

Parameters:

x (float or array_like) – Parameter value in physical space.

Returns:

Log probability density at x.

Return type:

float

sample(key: Array) → Array

Draw a random sample from the Student’s t distribution.

Parameters:

key (jax.Array) – JAX PRNG key for random sampling.

Returns:

A single sample from Student’s t distribution, clipped to [lo, hi].

Return type:

ndarray

standardize(theta: Array) → Array

Map a physical parameter value to a standardized coordinate via the Student-t scale.

Parameters:

theta (float or array_like) – Physical-space parameter value.

Returns:

Standardized latent-space value.

Return type:

float or ndarray

unstandardize(xi: Array) → Array

ξ ~ N(0,1) → t-distributed via Gaussian approximation.

For df > 2, a Gaussian with matched variance is a reasonable approximation for the bulk of the distribution. Its tails are lighter than those of the Student's t, however, so extreme values are underrepresented.

Parameters:

xi (float or array_like) – Standardized latent value from standard normal distribution.

Returns:

Physical-space parameter in [lo, hi].

Return type:

float or ndarray

Fixed

class tengri.Fixed(value: float | str)

Bases: Distribution

Fixed (non-free) parameter with a constant value.

Represents a parameter that is not sampled or inferred. Used for holding model settings constant during fitting, or for categorical parameters that don’t vary. Fixed parameters contribute nothing to the log-prior: log_prob always returns 0.

Parameters:

value (float, int, or str) – The fixed value. Can be numeric (for quantitative parameters) or string (for categorical choices, e.g. “solar” for shock abundance).

Returns:

Fixed instance with the given value.

Return type:

Fixed

value

The constant value returned by sample() and unstandardize().

Type:

float or str

bounds

(value, value) for numeric values, (None, None) for string values. Fixed parameters have no sampling range.

Type:

tuple[float, float] or tuple[None, None]

Notes

JIT-compatible: yes — unstandardize() returns the constant value regardless of the latent variable ξ.

Inference: Fixed parameters are excluded from the inference set. They do not appear in the posterior and do not contribute to the loss or gradients.
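The hypothetical sketch below shows how a fitting loop might separate free parameters from fixed ones using the is_fixed property documented below. FixedStub and FreeStub are stand-ins that let the snippet run without tengri, and split_free_and_fixed is not a tengri function. Using getattr with a default of False means any object that lacks is_fixed is treated as free.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FixedStub:
    """Hypothetical stand-in for tengri.Fixed (exposes value and is_fixed)."""
    value: object
    is_fixed: bool = True


@dataclass(frozen=True)
class FreeStub:
    """Hypothetical stand-in for a free prior such as Gaussian or Uniform."""
    kind: str


def split_free_and_fixed(priors: dict) -> tuple[list, dict]:
    """Hypothetical helper: free parameter names, plus {name: value} for fixed ones."""
    # Objects without an is_fixed attribute are treated as free.
    free = [name for name, p in priors.items() if not getattr(p, "is_fixed", False)]
    fixed = {name: p.value for name, p in priors.items() if getattr(p, "is_fixed", False)}
    return free, fixed


priors = {
    "redshift": FixedStub(0.1),
    "shock_abundance": FixedStub("solar"),
    "metallicity": FreeStub("Gaussian"),
}
free, fixed = split_free_and_fixed(priors)
print(free)   # ['metallicity']
print(fixed)  # {'redshift': 0.1, 'shock_abundance': 'solar'}
```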

Examples

>>> from tengri import Fixed
>>> # Numerical fixed value
>>> redshift = Fixed(0.1)
>>> print(redshift.sample(None))
0.1
>>> # Categorical fixed value
>>> shock_abundance = Fixed("solar")
>>> print(shock_abundance.sample(None))
solar

property bounds: tuple[float, float] | tuple[None, None]

Return (value, value) for numeric, or (None, None) for string.

Returns:

For numeric values: (value, value); for string values: (None, None).

Return type:

tuple[float, float] or tuple[None, None]

property is_fixed: bool

Return True — this is a fixed (non-free) parameter.

Returns:

Always True for Fixed distributions.

Return type:

bool

log_prob(x: Array) → Array

Return 0.0 (fixed parameters contribute nothing to the log-prior).

Parameters:

x (float or array_like) – Parameter value (ignored for fixed distributions).

Returns:

Always 0.0.

Return type:

float

sample(key: Array) → Array | str

Return the fixed value (ignores random key).

Parameters:

key (jax.Array) – JAX PRNG key (ignored for fixed parameters).

Returns:

The constant fixed value.

Return type:

float, int, or str

standardize(theta: Array) → Array

Fixed: returns 0 (no latent variable needed).

Parameters:

theta (float or array_like) – Physical-space parameter value (ignored for fixed parameters).

Returns:

Always 0.0.

Return type:

float

unstandardize(xi: Array) → Array | str

Fixed: always returns the fixed value (ignores ξ).

Parameters:

xi (float or array_like) – Standardized latent value (ignored for fixed parameters).

Returns:

The constant fixed value.

Return type:

float, int, or str

property value: float | str

The fixed value (numeric or string).

Returns:

The constant fixed value.

Return type:

float or str