Skip to content

Commit

Permalink
Implement Ordered distribution factory
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Dec 3, 2024
1 parent 3f3aeb9 commit 7020a27
Show file tree
Hide file tree
Showing 4 changed files with 238 additions and 29 deletions.
2 changes: 2 additions & 0 deletions pymc/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
WishartBartlett,
ZeroSumNormal,
)
from pymc.distributions.ordered import Ordered
from pymc.distributions.simulator import Simulator
from pymc.distributions.timeseries import (
AR,
Expand Down Expand Up @@ -178,6 +179,7 @@
"NegativeBinomial",
"Normal",
"NormalMixture",
"Ordered",
"OrderedLogistic",
"OrderedMultinomial",
"OrderedProbit",
Expand Down
6 changes: 2 additions & 4 deletions pymc/distributions/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,8 +1239,7 @@ class OrderedLogistic:
# Ordered logistic regression
with pm.Model() as model:
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
transform=pm.distributions.transforms.ordered)
cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
y_ = pm.OrderedLogistic("y", cutpoints=cutpoints, eta=x, observed=y)
idata = pm.sample()
Expand Down Expand Up @@ -1343,8 +1342,7 @@ class OrderedProbit:
# Ordered probit regression
with pm.Model() as model:
cutpoints = pm.Normal("cutpoints", mu=[-1,1], sigma=10, shape=2,
transform=pm.distributions.transforms.ordered)
cutpoints = pm.Ordered("cutpoints", dist=pm.Normal.dist(mu=0, sigma=10), shape=2)
y_ = pm.OrderedProbit("y", cutpoints=cutpoints, eta=x, observed=y)
idata = pm.sample()
Expand Down
133 changes: 133 additions & 0 deletions pymc/distributions/ordered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytensor.tensor as pt

from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.utils import normalize_size_param
from pytensor.tensor.variable import TensorVariable

from pymc.distributions.distribution import (
Distribution,
SymbolicRandomVariable,
_support_point,
)
from pymc.distributions.shape_utils import change_dist_size, get_support_shape_1d, rv_size_is_none
from pymc.distributions.transforms import _default_transform, ordered


class OrderedRV(SymbolicRandomVariable):
    """Symbolic random variable whose output is its input sorted along the last axis."""

    inline_logprob = True
    extended_signature = "(x)->(x)"
    _print_name = ("Ordered", "\\operatorname{Ordered}")

    @classmethod
    def rv_op(cls, dist, *, size=None):
        """Wrap ``dist`` in an ``OrderedRV`` that sorts it along the last axis.

        Parameters
        ----------
        dist
            Base variable to be sorted. Its last dimension is treated as the core
            dimension.
        size
            Optional batch size. When given, ``dist`` is resized to
            ``(*size, core_length)`` before sorting.
        """
        # `rng` is deliberately not accepted: we don't fully control the rng of
        # the components!
        batch_size = normalize_size_param(size)

        if rv_size_is_none(batch_size):
            base_rv = dist
        else:
            # Keep the core (last) dimension and swap the batch dimensions for `size`
            core_length = tuple(dist.shape)[-1]
            base_rv = change_dist_size(dist, (*tuple(batch_size), core_length))

        ordered_op = OrderedRV(inputs=[base_rv], outputs=[pt.sort(base_rv, axis=-1)])
        return ordered_op(base_rv)


class Ordered(Distribution):
    r"""Univariate IID Ordered distribution.

    The pdf of the ordered distribution is

    .. math::

        f(x_1, ..., x_n) = n!\prod_{i=1}^n f(x_{(i)}),

    where :math:`x_1 \leq x_2 \leq \ldots \leq x_n`

    Parameters
    ----------
    dist : unnamed_distribution
        Univariate IID distribution which will be sorted.

        .. warning:: dist will be cloned, rendering it independent of the one passed as input

    support_shape : int, optional
        Length of the sorted (last) dimension. When omitted it is derived from
        ``shape``, ``dims`` or ``observed`` via ``get_support_shape_1d``.

    Examples
    --------
    .. code-block:: python

        import pymc as pm

        with pm.Model():
            x = pm.Normal.dist(mu=0, sigma=1)  # Must be IID
            ordered_x = pm.Ordered("ordered_x", dist=x, shape=(3,))

        pm.draw(ordered_x, random_seed=52)  # array([0.05172346, 0.43970706, 0.91500416])
    """

    rv_type = OrderedRV
    rv_op = OrderedRV.rv_op

    def __new__(cls, name, dist, *, support_shape=None, **kwargs):
        support_shape = get_support_shape_1d(
            support_shape=support_shape,
            shape=None,  # shape will be checked in `cls.dist`
            dims=kwargs.get("dims", None),
            observed=kwargs.get("observed", None),
        )
        return super().__new__(cls, name, dist, support_shape=support_shape, **kwargs)

    @classmethod
    def dist(cls, dist, *, support_shape=None, **kwargs):
        """Create an unnamed Ordered variable from an unnamed IID univariate ``dist``.

        Raises
        ------
        ValueError
            If ``dist`` was not created via the ``.dist()`` API, or if any of its
            parameters is not fully broadcastable (i.e. it is not IID).
        NotImplementedError
            If ``dist`` is a multivariate distribution.
        """
        # `dist.owner` is None for root variables (e.g. `pt.vector()`); check it
        # explicitly so those get the informative ValueError below instead of an
        # AttributeError on `None.op`.
        if (
            not isinstance(dist, TensorVariable)
            or dist.owner is None
            or not isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable)
        ):
            raise ValueError(
                f"Ordered dist must be a distribution created via the `.dist()` API, got {type(dist)}"
            )
        if dist.owner.op.ndim_supp > 0:
            raise NotImplementedError("Ordering of multivariate distributions not supported")
        # IID check: every parameter must be broadcastable along all of its dimensions
        if not all(
            all(param.type.broadcastable) for param in dist.owner.op.dist_params(dist.owner)
        ):
            raise ValueError("Ordered dist must be an IID variable")

        support_shape = get_support_shape_1d(
            support_shape=support_shape,
            shape=kwargs.get("shape", None),
        )
        if support_shape is not None:
            dist = change_dist_size(dist, support_shape)

        # Sorting needs at least one dimension to act on
        dist = pt.atleast_1d(dist)

        return super().dist([dist], **kwargs)


@_default_transform.register(OrderedRV)
def default_transform_ordered(op, rv):
    """Return the ``ordered`` transform for float-typed variables, otherwise ``None``."""
    is_float = rv.type.dtype.startswith("float")
    return ordered if is_float else None


@_support_point.register(OrderedRV)
def support_point_ordered(op, rv, dist):
    """Return the support point used for an ``OrderedRV``: currently ``rv`` itself."""
    # FIXME: Returning the support point of the base `dist` does not work with the
    # default ordered transform: it leads to a -inf log_jac_det. The intended
    # implementation is kept below for reference.
    # return support_point(dist)
    return rv  # Draw from the prior
126 changes: 101 additions & 25 deletions pymc/logprob/order.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,65 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.math import Max
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.sort import SortOp
from pytensor.tensor.variable import TensorVariable

from pymc.logprob.abstract import (
MeasurableElemwise,
MeasurableOp,
_logcdf_helper,
_logprob,
_logprob_helper,
)
from pymc.logprob.rewriting import measurable_ir_rewrites_db
from pymc.logprob.utils import filter_measurable_variables
from pymc.logprob.utils import (
CheckParameterValue,
check_potential_measurability,
filter_measurable_variables,
)
from pymc.math import logdiffexp
from pymc.pytensorf import constant_fold


def _underlying_iid_rv(variable) -> TensorVariable | None:
    """Return the IID univariate base RV that ``variable`` is derived from, if any.

    The base RV must be connected to ``variable`` exclusively through measurable
    elemwise transforms whose non-measurable inputs are fully broadcastable, so
    that the same operation is applied to every element.

    Returns
    -------
    TensorVariable or None
        The base random variable, or ``None`` when no such variable exists, when
        it is not univariate, or when its parameters are not fully broadcastable.
    """
    # Imported locally, as in the original implementation
    from pymc.distributions.distribution import SymbolicRandomVariable
    from pymc.logprob.transforms import MeasurableTransform

    def iid_elemwise_root(var: TensorVariable) -> TensorVariable | None:
        node = var.owner
        if node is None:
            # Variables without an owner have no RV upstream
            return None
        if isinstance(node.op, RandomVariable | SymbolicRandomVariable):
            return var
        if isinstance(node.op, MeasurableTransform):
            if len(node.inputs) == 1:
                return iid_elemwise_root(node.inputs[0])
            # If the non-measurable inputs are broadcasted, it is still an IID operation.
            measurable_inp = node.op.measurable_input_idx
            other_inputs = [inp for i, inp in enumerate(node.inputs) if i != measurable_inp]
            if all(all(other_inp.type.broadcastable) for other_inp in other_inputs):
                return iid_elemwise_root(node.inputs[measurable_inp])
        return None

    # Check that the root is a univariate distribution linked by only elemwise operations
    latent_base_var = iid_elemwise_root(variable)

    if latent_base_var is None:
        return None

    latent_op = latent_base_var.owner.op

    if not (hasattr(latent_op, "dist_params") and latent_op.ndim_supp == 0):
        return None

    # Univariate IID test: every parameter must be broadcastable along all dimensions
    if not all(
        all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
    ):
        return None

    return cast(TensorVariable, latent_base_var)


class MeasurableMax(MeasurableOp, Max):
"""A placeholder used to specify a log-likelihood for a max sub-graph."""

Expand All @@ -77,31 +121,12 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab
if not filter_measurable_variables(node.inputs):
return None

# We allow Max of RandomVariables or Elemwise of univariate RandomVariables
if isinstance(base_var.owner.op, MeasurableElemwise):
latent_base_vars = [
var
for var in base_var.owner.inputs
if (var.owner and isinstance(var.owner.op, MeasurableOp))
]
if len(latent_base_vars) != 1:
return None
[latent_base_var] = latent_base_vars
else:
latent_base_var = base_var

latent_op = latent_base_var.owner.op
if not (hasattr(latent_op, "dist_params") and getattr(latent_op, "ndim_supp") == 0):
return None
# We allow Max of RandomVariables or IID Elemwise of univariate RandomVariables
latent_base_var = _underlying_iid_rv(base_var)

# univariate i.i.d. test which also rules out other distributions
if not all(
all(params.type.broadcastable) for params in latent_op.dist_params(latent_base_var.owner)
):
if not latent_base_var:
return None

base_var = cast(TensorVariable, base_var)

if node.op.axis is None:
axis = tuple(range(base_var.ndim))
else:
Expand All @@ -119,7 +144,7 @@ def find_measurable_max(fgraph: FunctionGraph, node: Apply) -> list[TensorVariab


measurable_ir_rewrites_db.register(
"find_measurable_max",
find_measurable_max.__name__,
find_measurable_max,
"basic",
"max",
Expand Down Expand Up @@ -158,3 +183,54 @@ def max_logprob_discrete(op, values, base_rv, **kwargs):

n = pt.prod(base_rv_shape)
return logdiffexp(n * logcdf, n * logcdf_prev)


class MeasurableSort(MeasurableOp, SortOp):
    """A placeholder used to specify a log-likelihood for a sort sub-graph.

    It combines ``SortOp`` with ``MeasurableOp`` and adds no behaviour of its own;
    the log-probability is supplied by the ``_logprob`` implementation registered
    for this class.
    """


@_logprob.register(MeasurableSort)
def sort_logprob(op, values, base_rv, axis, **kwargs):
    r"""Compute the log-likelihood graph for the `Sort` operation.

    For ``n`` IID draws with density ``f``, the joint density of the sorted values is

    .. math::

        f(x_{(1)}, ..., x_{(n)}) = n! \prod_{i=1}^n f(x_{(i)}),

    so the log-probability is ``log(n!)`` plus the sum of the base log-probabilities.

    NOTE(review): the reduction below is over the last axis, i.e. the sort is
    assumed to act on ``axis=-1`` (which is what ``OrderedRV`` builds) -- confirm
    before relying on it for sorts along other axes.
    """
    (value,) = values

    logprob = _logprob_helper(base_rv, value).sum(axis=-1)

    # `n` is the number of IID elements being sorted together, i.e. the length of
    # the axis that is summed over above (not the total number of elements).
    base_rv_shape = constant_fold(tuple(base_rv.shape), raise_not_constant=False)
    n = base_rv_shape[-1]
    # log(n!) = gammaln(n + 1): there are n! orderings of the unsorted draws
    sorted_logp = pt.gammaln(n + 1) + logprob

    # The sorted value is not really a parameter, but we include the check in
    # `CheckParameterValue` to allow costly sorting if `check_bounds=False` in a PyMC model
    return CheckParameterValue("value must be sorted", can_be_replaced_by_ninf=True)(
        sorted_logp, pt.eq(value, value.sort(axis=axis, kind=op.kind)).all()
    )


@node_rewriter(tracks=[SortOp])
def find_measurable_sort(fgraph, node):
    """Replace a ``SortOp`` applied to an IID univariate RV with ``MeasurableSort``.

    Returns ``None`` (no rewrite) when the node is already measurable, has no
    measurable inputs, is not rooted in an IID univariate RV, or when the axis is
    potentially measurable.
    """
    already_measurable = isinstance(node.op, MeasurableSort)
    if already_measurable or not filter_measurable_variables(node.inputs):
        return None

    base_var, axis = node.inputs

    # We allow Sort of RandomVariables or IID Elemwise of univariate RandomVariables,
    # and the axis must not be potentially measurable.
    if _underlying_iid_rv(base_var) is None or check_potential_measurability([axis]):
        return None

    measurable_sort = MeasurableSort(**node.op._props_dict())
    return [measurable_sort(base_var, axis)]


# Register the sort rewrite in the measurable IR rewrite database under the
# function's own name, passing "basic" and "sort" after it.
measurable_ir_rewrites_db.register(
    find_measurable_sort.__name__,
    find_measurable_sort,
    "basic",
    "sort",
)

0 comments on commit 7020a27

Please sign in to comment.