From 8c5e83623948c3210248fdfd7c071ef63baf8bd4 Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Tue, 12 Sep 2023 16:28:27 -0500 Subject: [PATCH 1/2] Implementation of Maxwell RV --- doc/library/tensor/random/basic.rst | 3 ++ pytensor/link/jax/dispatch/random.py | 1 + pytensor/link/numba/dispatch/random.py | 1 + pytensor/tensor/random/basic.py | 58 ++++++++++++++++++++++++++ tests/link/jax/test_random.py | 16 +++++++ tests/link/numba/test_random.py | 16 +++++++ tests/tensor/random/test_basic.py | 19 +++++++++ 7 files changed, 114 insertions(+) diff --git a/doc/library/tensor/random/basic.rst b/doc/library/tensor/random/basic.rst index 461fdf59be..0251e6d7a6 100644 --- a/doc/library/tensor/random/basic.rst +++ b/doc/library/tensor/random/basic.rst @@ -124,6 +124,9 @@ PyTensor can produce :class:`RandomVariable`\s that draw samples from many diffe .. autoclass:: pytensor.tensor.random.basic.LogNormalRV :members: __call__ +.. autoclass:: pytensor.tensor.random.basic.MaxwellRV + :members: __call__ + .. autoclass:: pytensor.tensor.random.basic.MultinomialRV :members: __call__ diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 0981234db0..aab69f109b 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -146,6 +146,7 @@ def sample_fn(rng, size, dtype, *parameters): @jax_sample_fn.register(aer.LogisticRV) @jax_sample_fn.register(aer.NormalRV) @jax_sample_fn.register(aer.StandardNormalRV) +@jax_sample_fn.register(aer.MaxwellRV) def jax_sample_fn_loc_scale(op): """JAX implementation of random variables in the loc-scale families. diff --git a/pytensor/link/numba/dispatch/random.py b/pytensor/link/numba/dispatch/random.py index 6de14cd3c5..33e3792ba5 100644 --- a/pytensor/link/numba/dispatch/random.py +++ b/pytensor/link/numba/dispatch/random.py @@ -194,6 +194,7 @@ def {sized_fn_name}({random_fn_input_names}): @numba_funcify.register(aer.BetaRV) @numba_funcify.register(aer.NormalRV) @numba_funcify.register(aer.LogNormalRV) +@numba_funcify.register(aer.MaxwellRV) @numba_funcify.register(aer.GammaRV) @numba_funcify.register(aer.ChiSquareRV) @numba_funcify.register(aer.ParetoRV) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 96c7913336..5f4d3f805f 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -419,6 +419,63 @@ def __call__(self, mean=0.0, sigma=1.0, size=None, **kwargs): lognormal = LogNormalRV() +class MaxwellRV(ScipyRandomVariable): + r"""A Maxwellian continuous random variable. + + The probability density function for `maxwell` in terms of its parameters :math:`\mu` + and :math:`\sigma` is: + + .. math:: + + f(x; \mu, \sigma) = \sqrt{\frac{2}{\pi}}\frac{(x-\mu)^2 \exp\left\{-(x-\mu)^2/(2\sigma^2)\}}{\sigma^3} + + for :math:`x \geq 0` and :math:`\sigma > 0` + + """ + name = "maxwell" + ndim_supp = 0 + ndims_params = [0, 0] + dtype = "floatX" + _print_name = ("Maxwell", "\\operatorname{Maxwell}") + + def __call__(self, loc, scale, size=None, **kwargs): + r"""Draw samples from a Maxwell distribution. + + Signature + --------- + + `(), () -> ()` + + Parameters + ---------- + loc + Location parameter :math:`\mu` of the distribution. + scale + Scale parameter :math:`\sigma` of the distribution. Must be + positive. + size + Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k` + independent, identically distributed random variables are + returned. Default is `None` in which case a single random variable + is returned. + + """ + return super().__call__(loc, scale, size=size, **kwargs) + + @classmethod + def rng_fn_scipy( + cls, + rng: Union[np.random.Generator, np.random.RandomState], + loc: Union[np.ndarray, float], + scale: Union[np.ndarray, float], + size: Optional[Union[List[int], int]], + ) -> np.ndarray: + return stats.maxwell.rvs(loc=loc, scale=scale, size=size, random_state=rng) + + +maxwell = MaxwellRV() + + class GammaRV(ScipyRandomVariable): r"""A gamma continuous random variable. @@ -2157,6 +2214,7 @@ def permutation(x, **kwargs): "lognormal", "halfnormal", "normal", + "maxwell", "beta", "triangular", "uniform", diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 54e4e09307..c1ae8107f5 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -213,6 +213,22 @@ def test_random_updates_input_storage_order(): "lognorm", lambda mu, sigma: (sigma, 0, np.exp(mu)), ), + ( + aer.maxwell, + [ + set_test_value( + at.lvector(), + np.array([1, 2], dtype=np.int64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "maxwell", + lambda *args: args, + ), ( aer.normal, [ diff --git a/tests/link/numba/test_random.py b/tests/link/numba/test_random.py index f0ddf3525f..db6947bcd2 100644 --- a/tests/link/numba/test_random.py +++ b/tests/link/numba/test_random.py @@ -85,6 +85,22 @@ ], at.as_tensor([3, 2]), ), + ( + aer.maxwell, + [ + set_test_value( + at.dvector(), + np.array([1.0, 2.0], dtype=np.float64), + ), + set_test_value( + at.dscalar(), + np.array(1.0, dtype=np.float64), + ), + ], + (2,), + "maxwell", + lambda *args: args, + ), pytest.param( aer.pareto, [ diff --git a/tests/tensor/random/test_basic.py b/tests/tensor/random/test_basic.py index 4032b9a673..1253c8bc97 100644 --- a/tests/tensor/random/test_basic.py +++ b/tests/tensor/random/test_basic.py @@ -40,6 +40,7 @@ laplace, logistic, lognormal, + maxwell, multinomial, multivariate_normal, nbinom, @@ -338,6 +339,24 @@ def test_lognormal_samples(mean, sigma, size): compare_sample_values(lognormal, mean, sigma, size=size) +@pytest.mark.parametrize( + "loc, sigma, size", + [ + (np.array(0, dtype=config.floatX), np.array(1, dtype=config.floatX), None), + (np.array(0, dtype=config.floatX), np.array(1, dtype=config.floatX), []), + ( + np.full((1, 2), 0, dtype=config.floatX), + np.array(1, dtype=config.floatX), + None, + ), + ], +) +def test_maxwell_samples(loc, sigma, size): + compare_sample_values( + maxwell, loc, sigma, size=size, test_fn=fixed_scipy_rvs("maxwell") + ) + + @pytest.mark.parametrize( "a, b, size", [ From 0095a9b521310de955cda1c3c3e0b376a9bbbe10 Mon Sep 17 00:00:00 2001 From: Trey Wenger Date: Tue, 12 Sep 2023 16:52:25 -0500 Subject: [PATCH 2/2] update docstring --- pytensor/tensor/random/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/tensor/random/basic.py b/pytensor/tensor/random/basic.py index 5f4d3f805f..e48bb81979 100644 --- a/pytensor/tensor/random/basic.py +++ b/pytensor/tensor/random/basic.py @@ -427,7 +427,7 @@ class MaxwellRV(ScipyRandomVariable): .. math:: - f(x; \mu, \sigma) = \sqrt{\frac{2}{\pi}}\frac{(x-\mu)^2 \exp\left\{-(x-\mu)^2/(2\sigma^2)\}}{\sigma^3} + f(x; \mu, \sigma) = \sqrt{\frac{2}{\pi}}\frac{(x-\mu)^2 e^{-(x-\mu)^2/(2\sigma^2)}}{\sigma^3} for :math:`x \geq 0` and :math:`\sigma > 0`