From d9b7ef41ef2ec26351cc647bc18b1d321404711f Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Tue, 3 Dec 2024 14:41:18 +0100
Subject: [PATCH] Implement Ordered distribution factory

---
 pymc/distributions/__init__.py |   2 +
 pymc/distributions/discrete.py |   6 +-
 pymc/distributions/ordered.py  | 133 +++++++++++++++++++++++++++++++++
 pymc/logprob/order.py          | 126 ++++++++++++++++++++++++-------
 4 files changed, 238 insertions(+), 29 deletions(-)
 create mode 100644 pymc/distributions/ordered.py

diff --git a/pymc/distributions/__init__.py b/pymc/distributions/__init__.py
index 442ebddc712..9b8d27f87c9 100644
--- a/pymc/distributions/__init__.py
+++ b/pymc/distributions/__init__.py
@@ -103,6 +103,7 @@
     WishartBartlett,
     ZeroSumNormal,
 )
+from pymc.distributions.ordered import Ordered
 from pymc.distributions.simulator import Simulator
 from pymc.distributions.timeseries import (
     AR,
@@ -178,6 +179,7 @@
     "NegativeBinomial",
     "Normal",
     "NormalMixture",
+    "Ordered",
     "OrderedLogistic",
     "OrderedMultinomial",
     "OrderedProbit",
diff --git a/pymc/distributions/discrete.py b/pymc/distributions/discrete.py
index 979f81dba0c..4ebf399e9e3 100644
--- a/pymc/distributions/discrete.py
+++ b/pymc/distributions/discrete.py
@@ -1239,8 +1239,7 @@ class OrderedLogistic:
 
         # Ordered logistic regression
         with pm.Model() as model:
-            cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
-                                  transform=pm.distributions.transforms.ordered)
+            cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
             y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
             idata = pm.sample()
 
@@ -1343,8 +1342,7 @@ class OrderedProbit:
 
         # Ordered probit regression
         with pm.Model() as model:
-            cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
-                                  transform=pm.distributions.transforms.ordered)
+            cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
             y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=x, observed=y)
             idata = pm.sample()
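For context, the docstring edits above swap the explicit `ordered` transform for the new `pm.Ordered` factory. A self-contained sketch of the updated usage follows; the toy `x` and `y` data are illustrative stand-ins, not part of the patch:

import numpy as np
import pymc as pm

# Toy stand-ins for the `x` (predictor) and `y` (ordinal response) assumed by the docstrings
rng = np.random.default_rng(0)
x = rng.normal(size=100)
y = rng.integers(0, 3, size=100)  # three categories -> two cutpoints

with pm.Model() as model:
    # IID Normal cutpoints constrained to be ordered by the new factory,
    # replacing the explicit transform=pm.distributions.transforms.ordered
    cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
    y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)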
diff --git a/pymc/distributions/ordered.py b/pymc/distributions/ordered.py
new file mode 100644
index 00000000000..2aa9e4da88e
--- /dev/null
+++ b/pymc/distributions/ordered.py
@@ -0,0 +1,133 @@
+# Copyright 2024 The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytensor.tensor as pt
+
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.utils import normalize_size_param
+from pytensor.tensor.variable import TensorVariable
+
+from pymc.distributions.distribution import (
+    Distribution,
+    SymbolicRandomVariable,
+    _support_point,
+)
+from pymc.distributions.shape_utils import change_dist_size, get_support_shape_1d, rv_size_is_none
+from pymc.distributions.transforms import _default_transform, ordered
+
+
+class OrderedRV(SymbolicRandomVariable):
+    inline_logprob = True
+    extended_signature = "(x)->(x)"
+    _print_name = ("Ordered", "\\operatorname{Ordered}")
+
+    @classmethod
+    def rv_op(cls, dist, *, size=None):
+        # We don't allow passing `rng` because we don't fully control the rng of the components!
+        size = normalize_size_param(size)
+
+        if not rv_size_is_none(size):
+            core_shape = tuple(dist.shape)[-1]
+            shape = (*tuple(size), core_shape)
+            dist = change_dist_size(dist, shape)
+
+        sorted_rv = pt.sort(dist, axis=-1)
+
+        return OrderedRV(
+            inputs=[dist],
+            outputs=[sorted_rv],
+        )(dist)
+
+
+class Ordered(Distribution):
+    r"""Univariate IID Ordered distribution.
+
+    The pdf of the ordered distribution is
+
+    .. math::
+
+        f(x_1, ..., x_n) = n! \prod_{i=1}^{n} f(x_{(i)}),
+
+    where :math:`x_1 \le x_2 \le ... \le x_n`.
+
+    Parameters
+    ----------
+    dist : unnamed_distribution
+        Univariate IID distribution which will be sorted.
+
+        .. warning:: dist will be cloned, rendering it independent of the one passed as input.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import pymc as pm
+
+        with pm.Model():
+            x = pm.Normal.dist(mu=0, sigma=1)  # Must be IID
+            ordered_x = pm.Ordered("ordered_x", dist=x, shape=(3,))
+
+        pm.draw(ordered_x, random_seed=52)  # array([0.05172346, 0.43970706, 0.91500416])
+    """
+
+    rv_type = OrderedRV
+    rv_op = OrderedRV.rv_op
+
+    def __new__(cls, name, dist, *, support_shape=None, **kwargs):
+        support_shape = get_support_shape_1d(
+            support_shape=support_shape,
+            shape=None,  # shape will be checked in `cls.dist`
+            dims=kwargs.get("dims", None),
+            observed=kwargs.get("observed", None),
+        )
+        return super().__new__(cls, name, dist, support_shape=support_shape, **kwargs)
+
+    @classmethod
+    def dist(cls, dist, *, support_shape=None, **kwargs):
+        if not isinstance(dist, TensorVariable) or not isinstance(
+            dist.owner.op, RandomVariable | SymbolicRandomVariable
+        ):
+            raise ValueError(
+                f"Ordered dist must be a distribution created via the `.dist()` API, got {type(dist)}"
+            )
+        if dist.owner.op.ndim_supp > 0:
+            raise NotImplementedError("Ordering of multivariate distributions not supported")
+        if not all(
+            all(param.type.broadcastable) for param in dist.owner.op.dist_params(dist.owner)
+        ):
+            raise ValueError("Ordered dist must be an IID variable")
+
+        support_shape = get_support_shape_1d(
+            support_shape=support_shape,
+            shape=kwargs.get("shape", None),
+        )
+        if support_shape is not None:
+            dist = change_dist_size(dist, support_shape)
+
+        dist = pt.atleast_1d(dist)
+
+        return super().dist([dist], **kwargs)
+
+
+@_default_transform.register(OrderedRV)
+def default_transform_ordered(op, rv):
+    if rv.type.dtype.startswith("float"):
+        return ordered
+    else:
+        return None
+
+
+@_support_point.register(OrderedRV)
+def support_point_ordered(op, rv, dist):
+    # FIXME: This does not work with the default ordered transform,
+    # which maps [0, 0, 0] to [0, -inf, -inf].
+    # return support_point(dist)
+    return rv  # Draw from the prior
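A quick numerical illustration of the density this factory implements (a sketch, not part of the patch, assuming the `pm.Ordered` API above): the logp of a sorted IID vector of length n is log(n!) plus the sum of the component logps.

import math

import numpy as np
import pymc as pm

# Illustrative sanity check of the docstring's pdf: n! * prod_i f(x_(i))
ordered_x = pm.Ordered.dist(dist=pm.Normal.dist(mu=0, sigma=1), shape=(3,))
value = np.array([-1.0, 0.0, 1.0])  # already sorted

logp = pm.logp(ordered_x, value).eval()
expected = math.log(math.factorial(3)) + pm.logp(pm.Normal.dist(0, 1), value).eval().sum()
np.testing.assert_allclose(logp, expected)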
diff --git a/pymc/logprob/order.py b/pymc/logprob/order.py
index 6eceb819dd8..faa97c72085 100644
--- a/pymc/logprob/order.py
+++ b/pymc/logprob/order.py
@@ -41,21 +41,65 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.tensor.math import Max
+from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.sort import SortOp
 from pytensor.tensor.variable import TensorVariable
 
 from pymc.logprob.abstract import (
-    MeasurableElemwise,
     MeasurableOp,
     _logcdf_helper,
     _logprob,
     _logprob_helper,
 )
 from pymc.logprob.rewriting import measurable_ir_rewrites_db
-from pymc.logprob.utils import filter_measurable_variables
+from pymc.logprob.utils import (
+    CheckParameterValue,
+    check_potential_measurability,
+    filter_measurable_variables,
+)
 from pymc.math import logdiffexp
 from pymc.pytensorf import constant_fold
 
 
+def _underlying_iid_rv(variable) -> TensorVariable | None:
+    # Check whether an IID base RV is connected to the variable through identical elemwise operations
+    from pymc.distributions.distribution import SymbolicRandomVariable
+    from pymc.logprob.transforms import MeasurableTransform
+
+    def iid_elemwise_root(var: TensorVariable) -> TensorVariable | None:
+        node = var.owner
+        if isinstance(node.op, RandomVariable | SymbolicRandomVariable):
+            return var
+        elif isinstance(node.op, MeasurableTransform):
+            if len(node.inputs) == 1:
+                return iid_elemwise_root(node.inputs[0])
+            else:
+                # If the non-measurable inputs are broadcasted, it is still an IID operation.
+                measurable_inp = node.op.measurable_input_idx
+                other_inputs = [inp for i, inp in enumerate(node.inputs) if i != measurable_inp]
+                if all(all(other_inp.type.broadcastable) for other_inp in other_inputs):
+                    return iid_elemwise_root(node.inputs[measurable_inp])
+        return None
+
+    # Check that the root is a univariate distribution linked by only elemwise operations
+    latent_base_var = iid_elemwise_root(variable)
+
+    if latent_base_var is None:
+        return None
+
+    latent_op = latent_base_var.owner.op
+
+    if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
+        return None
+
+    if not all(
+        all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
+    ):
+        return None
+
+    return cast(TensorVariable, latent_base_var)
+
+
 class MeasurableMax(MeasurableOp, Max):
     """A placeholder used to specify a log-likelihood for a max sub-graph."""
 
@@ -77,31 +121,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
     if not filter_measurable_variables(node.inputs):
         return None
 
-    # We allow Max of RandomVariables or Elemwise of univariate RandomVariables
-    if isinstance(base_var.owner.op, MeasurableElemwise):
-        latent_base_vars = [
-            var
-            for var in base_var.owner.inputs
-            if (var.owner and isinstance(var.owner.op, MeasurableOp))
-        ]
-        if len(latent_base_vars) != 1:
-            return None
-        [latent_base_var] = latent_base_vars
-    else:
-        latent_base_var = base_var
-
-    latent_op = latent_base_var.owner.op
-    if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
-        return None
+    # We allow Max of RandomVariables or IID Elemwise of univariate RandomVariables
+    latent_base_var = _underlying_iid_rv(base_var)
 
-    # univariate i.i.d. test which also rules out other distributions
-    if not all(
-        all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
-    ):
+    if latent_base_var is None:
         return None
 
-    base_var = cast(TensorVariable, base_var)
-
     if node.op.axis is None:
         axis = tuple(range(base_var.ndim))
     else:
@@ -119,7 +144,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
 
 
 measurable_ir_rewrites_db.register(
-    "find_measurable_max",
+    find_measurable_max.__name__,
    find_measurable_max,
    "basic",
    "max",
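As a numerical sketch (not part of the patch) of the formula the existing `max_logprob` encodes for n IID continuous components, logp_max(x) = log(n) + (n - 1) * logcdf(x) + logpdf(x):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

# Illustrative check: the max of 5 IID standard Normals at x = 0.7
base = pm.Normal.dist(0, 1, size=5)
x = 0.7
expected = (
    np.log(5)
    + 4 * pm.logcdf(pm.Normal.dist(0, 1), x).eval()
    + pm.logp(pm.Normal.dist(0, 1), x).eval()
)
np.testing.assert_allclose(pm.logp(pt.max(base), x).eval(), expected)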
@@ -158,3 +183,54 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):
     n = pt.prod(base_rv_shape)
 
     return logdiffexp(n * logcdf, n * logcdf_prev)
+
+
+class MeasurableSort(MeasurableOp, SortOp):
+    """A placeholder used to specify a log-likelihood for a sort sub-graph."""
+
+
+@_logprob.register(MeasurableSort)
+def sort_logprob(op, values, base_rv, axis, **kwargs):
+    r"""Compute the log-likelihood graph for the `Sort` operation."""
+    (value,) = values
+
+    logprob = _logprob_helper(base_rv, value).sum(axis=-1)
+
+    base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
+    n = pt.prod(base_rv_shape, axis=-1)
+    sorted_logp = pt.gammaln(n + 1) + logprob
+
+    # The sorted value is not really a parameter, but we include the check in
+    # `CheckParameterValue` to avoid costly sorting if `check_bounds=False` in a PyMC model
+    return CheckParameterValue("value must be sorted", can_be_replaced_by_ninf=True)(
+        sorted_logp, pt.eq(value, value.sort(axis=axis, kind=op.kind)).all()
+    )
+
+
+@node_rewriter(tracks=[SortOp])
+def find_measurable_sort(fgraph, node):
+    if isinstance(node.op, MeasurableSort):
+        return None
+
+    if not filter_measurable_variables(node.inputs):
+        return None
+
+    [base_var, axis] = node.inputs
+
+    # We allow Sort of RandomVariables or IID Elemwise of univariate RandomVariables
+    if _underlying_iid_rv(base_var) is None:
+        return None
+
+    # Check that the axis is not potentially measurable
+    if check_potential_measurability([axis]):
+        return None
+
+    return [MeasurableSort(**node.op._props_dict())(base_var, axis)]
+
+
+measurable_ir_rewrites_db.register(
+    find_measurable_sort.__name__,
+    find_measurable_sort,
+    "basic",
+    "sort",
+)
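With `find_measurable_sort` registered, the logp of a sorted IID RandomVariable can be derived directly, without going through the `Ordered` factory. An illustrative sketch (not part of the patch):

import pymc as pm
import pytensor.tensor as pt

# Sorting an IID RV now yields a measurable variable
base = pm.Normal.dist(mu=0, sigma=1, size=3)
sorted_rv = pt.sort(base, axis=-1)

# Evaluates to log(3!) plus the summed Normal logps; unsorted values are rejected
# by the "value must be sorted" CheckParameterValue guard in `sort_logprob`
print(pm.logp(sorted_rv, [-0.5, 0.1, 0.9]).eval())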