Distributions¶
Prior distributions for model parameters. Each distribution provides
sample, log_prob, and transform methods compatible with JAX
tracing.
Uniform¶
- class tengri.Uniform(lo: float, hi: float)[source]¶
Bases: Distribution
Uniform prior on [lo, hi].
A flat probability density on the interval [lo, hi]. Commonly used for bounded astrophysical quantities with little prior knowledge. Reparameterizes via sigmoid to ensure differentiability and automatic bound satisfaction.
- Parameters:
lo (float) – Lower bound of the interval.
hi (float) – Upper bound of the interval. Must satisfy hi > lo.
- Raises:
ValueError – If lo >= hi.
Notes
JIT-compatible: yes — all operations use jnp primitives.
Standardization: Maps ξ ~ N(0,1) to θ via the sigmoid function:
\[\theta = lo + (hi - lo) \cdot \sigma(\xi)\]
where σ(ξ) = 1 / (1 + exp(-ξ)). At ξ = 0 (the prior center), θ = (lo + hi) / 2. This ensures automatic bound satisfaction and smooth gradients.
Examples
>>> import jax.random
>>> from tengri import Uniform
>>> prior = Uniform(0, 1)
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample: {sample:.4f}")  # Will be in [0, 1)
>>> log_prob = prior.log_prob(0.5)
>>> print(f"log p(0.5): {log_prob:.4f}")  # ≈ 0.0 (log(1) = 0)
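The sigmoid standardization described in the notes can be sketched in plain Python. The function name below is illustrative, not part of tengri's API; tengri itself performs this mapping with jnp primitives.

```python
import math

def uniform_unstandardize(xi: float, lo: float, hi: float) -> float:
    """Map a standardized coordinate xi ~ N(0, 1) into [lo, hi]
    via theta = lo + (hi - lo) * sigma(xi)."""
    s = 1.0 / (1.0 + math.exp(-xi))  # sigmoid: strictly inside (0, 1), monotone
    return lo + (hi - lo) * s

# xi = 0 (the prior center) maps to the interval midpoint;
# large |xi| approaches, but never reaches, the bounds.
print(uniform_unstandardize(0.0, 0.0, 1.0))  # 0.5
```

Because the sigmoid stays strictly inside (0, 1), gradients with respect to ξ are smooth everywhere and θ can never violate the bounds.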
- property hi: float¶
Upper bound of the uniform distribution.
- Returns:
Upper bound value.
- Return type:
float
- property lo: float¶
Lower bound of the uniform distribution.
- Returns:
Lower bound value.
- Return type:
float
- sample(key: Array) Array[source]¶
Draw one sample uniformly from [lo, hi].
- Parameters:
key (jax.Array) – JAX PRNG key for random sampling.
- Returns:
A single sample uniformly distributed in [lo, hi].
- Return type:
jax.Array
Gaussian¶
- class tengri.Gaussian(mu: float, sigma: float, lo: float = -inf, hi: float = inf)[source]¶
Bases: Distribution
Gaussian (normal) prior, optionally clipped to [lo, hi].
A bell-curve probability density centered at μ with standard deviation σ. Useful when prior information suggests a most-probable value with uncertainty. Optional bounds allow truncation to physical ranges.
- Parameters:
mu (float) – Mean (center) of the distribution.
sigma (float) – Standard deviation. Must be positive.
lo (float, optional) – Lower truncation bound. Default: -∞.
hi (float, optional) – Upper truncation bound. Default: +∞.
- Raises:
ValueError – If sigma <= 0 or lo >= hi.
Notes
JIT-compatible: yes — all operations use jnp primitives.
Standardization: Maps ξ ~ N(0,1) to θ via:
\[\theta = \text{clip}(\mu + \sigma \cdot \xi, lo, hi)\]
Normalization: When lo and hi are finite, the density is normalized over [lo, hi], not over the full real line. Use this for physically bounded quantities.
Examples
>>> import jax.random
>>> from tengri import Gaussian
>>> prior = Gaussian(mu=-0.3, sigma=0.2)  # metallicity
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample: {sample:.3f}")  # Typically near -0.3
>>> log_prob = prior.log_prob(-0.3)
>>> print(f"log p(μ): {log_prob:.4f}")  # Maximum at mean
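The clip mapping and the truncation-aware normalization described in the notes can be sketched in plain Python. Function names are illustrative, not tengri's API; tengri's implementation uses jnp primitives.

```python
import math

def gaussian_unstandardize(xi, mu, sigma, lo=-math.inf, hi=math.inf):
    """theta = clip(mu + sigma * xi, lo, hi)."""
    return min(max(mu + sigma * xi, lo), hi)

def truncated_normal_logpdf(x, mu, sigma, lo, hi):
    """Log-density of N(mu, sigma^2) renormalized over [lo, hi].
    Phi (the standard normal CDF) is written via the error function."""
    if not lo <= x <= hi:
        return -math.inf
    phi = lambda z: 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    z = (x - mu) / sigma
    # Probability mass of the untruncated Gaussian inside [lo, hi]
    log_norm = math.log(phi((hi - mu) / sigma) - phi((lo - mu) / sigma))
    return -0.5 * z * z - math.log(sigma * math.sqrt(2.0 * math.pi)) - log_norm
```

With lo = -∞ and hi = +∞ the normalization term vanishes and the standard Gaussian log-density is recovered; finite bounds raise the density inside the interval, since the same mass is squeezed into less room.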
- property hi: float¶
Upper truncation bound.
- Returns:
Upper truncation bound (+inf if unbounded).
- Return type:
float
- property lo: float¶
Lower truncation bound.
- Returns:
Lower truncation bound (-inf if unbounded).
- Return type:
float
- sample(key: Array) Array[source]¶
Draw a random sample from the Gaussian distribution.
- Parameters:
key (jax.Array) – JAX PRNG key for random sampling.
- Returns:
A single sample from N(mu, sigma²), clipped to [lo, hi].
- Return type:
jax.Array
- property sigma: float¶
Standard deviation of the Gaussian distribution.
- Returns:
Standard deviation value.
- Return type:
float
LogUniform¶
- class tengri.LogUniform(lo: float, hi: float)[source]¶
Bases: Distribution
Uniform in log10 space on [lo, hi].
A prior that places equal probability in logarithmic intervals, resulting in power-law density in linear space. Useful for quantities with logarithmic uncertainties, such as star formation rates, timescales, and luminosities.
- Parameters:
lo (float) – Lower bound. Must be positive.
hi (float) – Upper bound. Must satisfy hi > lo.
- Raises:
ValueError – If lo <= 0 or lo >= hi.
Notes
JIT-compatible: yes — all operations use jnp primitives.
The probability density in linear space is:
\[p(x) = \frac{1}{x \cdot \ln(10) \cdot \log_{10}(hi / lo)}\]
where x ∈ [lo, hi]. Samples drawn from this distribution are uniformly distributed in log10 space: log10(x) ~ U(log10(lo), log10(hi)).
Standardization: Maps ξ ~ N(0,1) to θ via a sigmoid in log space:
\[\theta = 10^{\log_{10}(lo) + (\log_{10}(hi) - \log_{10}(lo)) \cdot \sigma(\xi)}\]
Examples
>>> import jax.random
>>> from tengri import LogUniform
>>> prior = LogUniform(1e-2, 1e2)  # ~4 orders of magnitude
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample: {sample:.3e}")
>>> log_prob = prior.log_prob(1.0)  # Center of log space
>>> print(f"log p(1.0): {log_prob:.4f}")
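The log-space sigmoid mapping and the linear-space density can be sketched in plain Python (illustrative names, not tengri's API). Note that ln(10) · log10(hi/lo) = ln(hi/lo), which is the form used by log_prob below.

```python
import math

def loguniform_unstandardize(xi, lo, hi):
    """theta = 10 ** (log10(lo) + (log10(hi) - log10(lo)) * sigma(xi))."""
    s = 1.0 / (1.0 + math.exp(-xi))
    a, b = math.log10(lo), math.log10(hi)
    return 10.0 ** (a + (b - a) * s)

def loguniform_logpdf(x, lo, hi):
    """log p(x) = -log(x * ln(hi / lo)) inside [lo, hi], -inf outside."""
    if not lo <= x <= hi:
        return -math.inf
    return -math.log(x * math.log(hi / lo))

# xi = 0 maps to the log-space midpoint: geometric mean of lo and hi.
print(loguniform_unstandardize(0.0, 1e-2, 1e2))  # 1.0
```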
- property hi: float¶
Upper bound of the log-uniform distribution.
- Returns:
Upper bound value.
- Return type:
float
- property lo: float¶
Lower bound of the log-uniform distribution.
- Returns:
Lower bound value.
- Return type:
float
- log_prob(x: Array) Array[source]¶
Return log probability: -log(x * log(hi/lo)) inside bounds, -inf outside.
- sample(key: Array) Array[source]¶
Draw one sample log-uniformly from [lo, hi].
- Parameters:
key (jax.Array) – JAX PRNG key for random sampling.
- Returns:
A single sample log-uniformly distributed in [lo, hi].
- Return type:
jax.Array
LogNormal¶
- class tengri.LogNormal(mu: float = 0.0, sigma: float = 1.0, lo: float = 0.0, hi: float = inf)[source]¶
Bases: Distribution
Log-normal prior: log(θ) ~ N(μ, σ²).
A prior suitable for positive-definite quantities with multiplicative uncertainty, such as timescales, amplitudes, and scale factors. The log of the parameter is normally distributed.
- Parameters:
mu (float, optional) – Mean of log(θ). Default: 0.0.
sigma (float, optional) – Standard deviation of log(θ). Must be positive. Default: 1.0.
lo (float, optional) – Lower truncation bound. Default: 0.0.
hi (float, optional) – Upper truncation bound. Default: +∞.
- Raises:
ValueError – If sigma <= 0.
Notes
JIT-compatible: yes — all operations use jnp primitives.
The probability density in linear space is:
\[p(\theta) = \frac{1}{\theta \sigma \sqrt{2\pi}} \exp\left( -\frac{(\ln \theta - \mu)^2}{2\sigma^2} \right)\]
for θ ∈ [lo, hi]. When truncated, the density is renormalized over the interval.
Standardization: Maps ξ ~ N(0,1) to θ via:
\[\theta = \text{clip}(\exp(\mu + \sigma \cdot \xi), lo, hi)\]
Examples
>>> import jax.random
>>> from tengri import LogNormal
>>> # PSD timescale: log(tau_yr) centered at 8, width 0.5 dex
>>> prior = LogNormal(mu=8, sigma=0.5)
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample (yr): {sample:.3e}")
>>> log_prob = prior.log_prob(1e8)
>>> print(f"log p(1e8): {log_prob:.4f}")
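The untruncated density and the exp-clip mapping above can be sketched in plain Python, following the natural-log form of the displayed formula (function names are illustrative, not tengri's API):

```python
import math

def lognormal_logpdf(theta, mu, sigma):
    """Untruncated log-density:
    log p(theta) = -log(theta * sigma * sqrt(2*pi)) - (ln(theta) - mu)^2 / (2*sigma^2)."""
    if theta <= 0.0:
        return -math.inf  # support is theta > 0
    z = (math.log(theta) - mu) / sigma
    return -math.log(theta * sigma * math.sqrt(2.0 * math.pi)) - 0.5 * z * z

def lognormal_unstandardize(xi, mu, sigma, lo=0.0, hi=math.inf):
    """theta = clip(exp(mu + sigma * xi), lo, hi): positive by construction."""
    return min(max(math.exp(mu + sigma * xi), lo), hi)

# xi = 0 maps to exp(mu), the median of the untruncated distribution.
print(lognormal_unstandardize(0.0, 0.0, 1.0))  # 1.0
```

The exponential guarantees positivity before any clipping, which is why this prior suits positive-definite quantities such as timescales and amplitudes.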
- property mu: float¶
Mean of log(theta).
- Returns:
Mean of the logarithm of the parameter.
- Return type:
float
- sample(key: Array) Array[source]¶
Draw a random sample from the log-normal distribution.
- Parameters:
key (jax.Array) – JAX PRNG key for random sampling.
- Returns:
A single sample from LogNormal(mu, sigma²), clipped to [lo, hi].
- Return type:
jax.Array
- property sigma: float¶
Standard deviation of log(theta).
- Returns:
Standard deviation of the logarithm of the parameter.
- Return type:
float
StudentT¶
- class tengri.StudentT(mu: float = 0.0, sigma: float = 1.0, df: float = 3.0, lo: float = -inf, hi: float = inf)[source]¶
Bases: Distribution
Student’s t prior with heavier tails than a Gaussian.
A robust prior with longer tails, useful for parameters that may exhibit outlier-like behavior. Commonly used in BAGPIPES-style SED fitting for down-weighting extreme values while remaining flexible.
- Parameters:
mu (float, optional) – Location (center) of the distribution. Default: 0.0.
sigma (float, optional) – Scale parameter. Must be positive. Default: 1.0.
df (float, optional) – Degrees of freedom, controlling tail weight: df → ∞ gives Gaussian (heaviest concentration at center); df = 3 gives a moderately heavy-tailed prior; df = 1 gives Cauchy (extremely heavy tails). Default: 3.0.
lo (float, optional) – Lower truncation bound. Default: -∞ (no lower truncation).
hi (float, optional) – Upper truncation bound. Default: +∞ (no upper truncation).
Notes
JIT-compatible: yes — all operations use jnp primitives.
The probability density follows a Student’s t distribution with the standard normalization. For finite df, it has heavier tails than a Gaussian.
Standardization: Uses a Gaussian approximation with variance scaling:
\[\theta = \text{clip}(\mu + \sigma \cdot \sqrt{df/(df-2)} \cdot \xi, lo, hi)\]
This is valid for df > 2. For df ≤ 2, a fallback scale of 3 is used.
Examples
>>> import jax.random
>>> from tengri import StudentT
>>> prior = StudentT(mu=0, sigma=1, df=3)  # Robust prior
>>> key = jax.random.PRNGKey(0)
>>> sample = prior.sample(key)
>>> print(f"Sample: {sample:.4f}")
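The variance-matched standardization from the notes can be sketched in plain Python (illustrative name, not tengri's API). The √(df/(df−2)) factor matches the t distribution's variance, which exists only for df > 2; below that, the documented fallback scale of 3 applies.

```python
import math

def studentt_unstandardize(xi, mu, sigma, df, lo=-math.inf, hi=math.inf):
    """theta = clip(mu + sigma * sqrt(df / (df - 2)) * xi, lo, hi),
    with a fallback scale of 3 when df <= 2 (variance undefined)."""
    scale = math.sqrt(df / (df - 2.0)) if df > 2.0 else 3.0
    return min(max(mu + sigma * scale * xi, lo), hi)

# df = 3 inflates the unit Gaussian scale by sqrt(3) ≈ 1.732.
print(studentt_unstandardize(1.0, 0.0, 1.0, 3.0))
```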
- sample(key: Array) Array[source]¶
Draw a random sample from the Student’s t distribution.
- Parameters:
key (jax.Array) – JAX PRNG key for random sampling.
- Returns:
A single sample from Student’s t distribution, clipped to [lo, hi].
- Return type:
jax.Array
- standardize(theta: Array) Array[source]¶
Map a physical parameter value to a standardized coordinate via the Student-t scale.
Fixed¶
- class tengri.Fixed(value: float | str)[source]¶
Bases: Distribution
Fixed (non-free) parameter with a constant value.
Represents a parameter that is not sampled or inferred. Used for holding model settings constant during fitting, or for categorical parameters that don’t vary. Fixed parameters contribute zero to the likelihood.
- Parameters:
value (float, int, or str) – The fixed value. Can be numeric (for quantitative parameters) or string (for categorical choices, e.g. “solar” for shock abundance).
- Returns:
Fixed instance with the given value.
- Return type:
Fixed
Notes
JIT-compatible: yes — unstandardize() returns the constant value regardless of the latent variable ξ.
Inference: Fixed parameters are excluded from the inference set. They do not appear in the posterior and do not contribute to the loss or gradients.
Examples
>>> from tengri import Fixed
>>> # Numerical fixed value
>>> redshift = Fixed(0.1)
>>> print(redshift.sample(None))
0.1
>>> # Categorical fixed value
>>> shock_abundance = Fixed("solar")
>>> print(shock_abundance.sample(None))
solar
- property bounds: tuple[float, float] | tuple[None, None]¶
Return (value, value) for numeric, or (None, None) for string.
- property is_fixed: bool¶
Return True — this is a fixed (non-free) parameter.
- Returns:
Always True for Fixed distributions.
- Return type:
bool
- log_prob(x: Array) Array[source]¶
Return 0.0 (fixed parameters have zero log-likelihood contribution).