# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Specialized priors that behave like the Prior class.
The Prior class has certain design constraints that prevent it from
covering all cases. So this module contains a collection of
priors that do not inherit from the Prior class but have many
of the same methods.
"""
import warnings
from typing import Any
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pymc_extras.deserialize import deserialize, register_deserialization
from pymc_extras.prior import Prior, VariableFactory, create_dim_handler, sample_prior
from pytensor.tensor import TensorVariable
class LogNormalPrior:
r"""Lognormal prior parameterized by positive-scale mean and std.
A lognormal prior parameterized by mean and standard deviation
on the positive domain, with optional centered or non-centered
parameterization.
This prior differs from the standard ``LogNormal`` distribution,
which takes log-scale parameters (``mu_log``, ``sigma_log``).
Instead, it is parameterized directly in terms of the mean and
standard deviation (``mean``, ``std``) on the positive scale, making it more intuitive
and suitable for hierarchical modeling.
To achieve this, the lognormal parameters are computed internally
from the positive-domain parameters:

    .. math::

        \mu_{\log} &= \ln \left( \frac{\text{mean}^2}{\sqrt{\text{mean}^2 + \text{std}^2}} \right) \\
        \sigma_{\log} &= \sqrt{ \ln \left( 1 + \frac{\text{std}^2}{\text{mean}^2} \right) }

    where :math:`\text{mean} > 0` and :math:`\text{std} > 0`.

    The prior is then defined as:

    .. math::

        \phi \sim \text{LogNormal}(\mu_{\log}, \sigma_{\log})

    This construction ensures that the resulting random variable has
    approximately the intended mean and variance on the positive scale,
    even when :math:`\text{mean}` and :math:`\text{std}` are themselves
    random variables.
Parameters
----------
    mean : Prior, float, int, or array-like
        The mean of the distribution on the positive scale. ``mu`` is
        accepted as an alias for this parameter.
    std : Prior, float, int, or array-like
        The standard deviation of the distribution on the positive scale.
        ``sigma`` is accepted as an alias for this parameter.
    dims : tuple[str, ...], optional
        The dimensions of the distribution, by default None.
    centered : bool, optional
        Whether to use the centered parameterization, by default True.
Examples
--------
Build a non-centered hierarchical model where information is shared across groups:
    .. code-block:: python

        from pymc_extras.prior import Prior
        from pymc_marketing.special_priors import LogNormalPrior

        prior = LogNormalPrior(
            mean=Prior("Gamma", mu=1.0, sigma=1.0),
            std=Prior("HalfNormal", sigma=1.0),
            dims=("geo",),
            centered=False,
        )
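
    Sample the prior and check the positive-scale moments. A minimal sketch
    with illustrative values (``mean=2.0``, ``std=0.5``); the draws should
    recover roughly that mean and standard deviation:

    .. code-block:: python

        from pymc_marketing.special_priors import LogNormalPrior

        prior = LogNormalPrior(mean=2.0, std=0.5)
        samples = prior.sample_prior(draws=5_000, random_seed=1)["variable"]
        print(float(samples.mean()), float(samples.std()))  # close to 2.0 and 0.5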
References
----------
- D. Saunders, *A positive constrained non-centered prior that sparks joy*.
- Wikipedia, *Log-normal distribution — Definitions*.
"""
def __init__(self, dims: tuple | None = None, centered: bool = True, **parameters):
# Accept aliases mu->mean and sigma->std for convenience/compatibility
if "mean" not in parameters and "mu" in parameters:
parameters["mean"] = parameters.pop("mu")
if "std" not in parameters and "sigma" in parameters:
parameters["std"] = parameters.pop("sigma")
self.parameters = parameters
self.dims = dims
self.centered = centered
self._checks()
def _checks(self) -> None:
self._parameters_are_correct_set()
def _parameters_are_correct_set(self) -> None:
        # Only allow exactly these keys after alias normalization
        if set(self.parameters.keys()) != {"mean", "std"}:
            raise ValueError("Parameters must be exactly 'mean' and 'std'")
    def _create_parameter(self, param, value, name):
        # Pass scalars and arrays through unchanged; anything exposing
        # ``create_variable`` (e.g. a Prior) becomes a child variable
        if not hasattr(value, "create_variable"):
            return value
child_name = f"{name}_{param}"
return self.dim_handler(value.create_variable(child_name), value.dims)
def create_variable(self, name: str) -> TensorVariable:
"""Create a variable from the prior distribution."""
self.dim_handler = create_dim_handler(self.dims)
parameters = {
param: self._create_parameter(param, value, name)
for param, value in self.parameters.items()
}
mu_log = pt.log(
parameters["mean"] ** 2
/ pt.sqrt(parameters["mean"] ** 2 + parameters["std"] ** 2)
)
sigma_log = pt.sqrt(
pt.log(1 + (parameters["std"] ** 2 / parameters["mean"] ** 2))
)
if self.centered:
log_phi = pm.Normal(
name + "_log", mu=mu_log, sigma=sigma_log, dims=self.dims
)
else:
            log_phi_z = pm.Normal(
                name + "_log_offset", mu=0, sigma=1, dims=self.dims
            )
log_phi = mu_log + log_phi_z * sigma_log
phi = pm.math.exp(log_phi)
phi = pm.Deterministic(name, phi, dims=self.dims)
return phi
def to_dict(self):
"""Convert the prior distribution to a dictionary."""
data = {
"special_prior": "LogNormalPrior",
}
if self.parameters:
def handle_value(value):
if isinstance(value, Prior):
return value.to_dict()
if isinstance(value, pt.TensorVariable):
value = value.eval()
if isinstance(value, np.ndarray):
return value.tolist()
if hasattr(value, "to_dict"):
return value.to_dict()
return value
data["kwargs"] = {
param: handle_value(value) for param, value in self.parameters.items()
}
if not self.centered:
data["centered"] = False
if self.dims:
data["dims"] = self.dims
return data
    @classmethod
    def from_dict(cls, data) -> "LogNormalPrior":
        """Create a LogNormalPrior from a dictionary."""
if not isinstance(data, dict):
msg = (
"Must be a dictionary representation of a prior distribution. "
f"Not of type: {type(data)}"
)
raise ValueError(msg)
kwargs = data.get("kwargs", {})
def handle_value(value):
if isinstance(value, dict):
return deserialize(value)
if isinstance(value, list):
return np.array(value)
return value
kwargs = {param: handle_value(value) for param, value in kwargs.items()}
centered = data.get("centered", True)
dims = data.get("dims")
if isinstance(dims, list):
dims = tuple(dims)
return cls(dims=dims, centered=centered, **kwargs)
def sample_prior(
self,
coords=None,
name: str = "variable",
**sample_prior_predictive_kwargs,
) -> xr.Dataset:
"""Sample from the prior distribution."""
return sample_prior(
factory=self,
coords=coords,
name=name,
**sample_prior_predictive_kwargs,
)
def _is_LogNormalPrior_type(data: dict) -> bool:
    return data.get("special_prior") == "LogNormalPrior"
register_deserialization(
is_type=_is_LogNormalPrior_type,
deserialize=LogNormalPrior.from_dict,
)
class MaskedPrior:
"""Create variables from a prior over only the active entries of a boolean mask.
.. warning::
This class is experimental and its API may change in future versions.
Parameters
----------
    prior : Prior
        Base prior whose variable is defined over ``prior.dims``. Internally,
        the variable is created only for the active entries given by ``mask``
        and then expanded back to the full shape with zeros at inactive
        positions.
    mask : xarray.DataArray
        Boolean array whose dims match ``prior.dims`` in the same order and
        whose shape matches the prior's, marking active (True) and inactive
        (False) entries.
    active_dim : str, optional
        Name of the coordinate indexing the active subset. If not provided, a
        name is generated as ``"non_null_dims:<dim1>_<dim2>_..."``; for
        example, ``dims=("country",)`` yields ``"non_null_dims:country"``. If
        an existing coordinate with the same name has a different length, a
        suffix with the active length is appended.
Examples
--------
Simple 1D masking.
    .. code-block:: python

        import pymc as pm
        import xarray as xr

        from pymc_extras.prior import Prior
        from pymc_marketing.special_priors import MaskedPrior

        coords = {"country": ["Venezuela", "Colombia"]}
        mask = xr.DataArray(
            [True, False],
            dims=["country"],
            coords={"country": coords["country"]},
        )
        intercept = Prior("Normal", mu=0, sigma=10, dims=("country",))

        with pm.Model(coords=coords):
            masked = MaskedPrior(intercept, mask)
            intercept_full = masked.create_variable("intercept")
Nested parameter priors with dims remapped to the active subset.
    .. code-block:: python

        import pymc as pm
        import xarray as xr

        from pymc_extras.prior import Prior
        from pymc_marketing.special_priors import MaskedPrior

        coords = {"country": ["Venezuela", "Colombia"]}
        mask = xr.DataArray(
            [True, False],
            dims=["country"],
            coords={"country": coords["country"]},
        )
        intercept = Prior(
            "Normal",
            mu=Prior("HalfNormal", sigma=1, dims=("country",)),
            sigma=10,
            dims=("country",),
        )

        with pm.Model(coords=coords):
            masked = MaskedPrior(intercept, mask)
            intercept_full = masked.create_variable("intercept")
All entries masked (returns deterministic zeros with original dims).
    .. code-block:: python

        import pymc as pm
        import xarray as xr

        from pymc_extras.prior import Prior
        from pymc_marketing.special_priors import MaskedPrior

        coords = {"country": ["Venezuela", "Colombia"]}
        mask = xr.DataArray(
            [False, False],
            dims=["country"],
            coords={"country": coords["country"]},
        )
        prior = Prior("Normal", mu=0, sigma=10, dims=("country",))

        with pm.Model(coords=coords):
            masked = MaskedPrior(prior, mask)
            zeros = masked.create_variable("intercept")
    Apply to a saturation function's priors:

    .. code-block:: python

        import xarray as xr

        from pymc_extras.prior import Prior
        from pymc_marketing.mmm import LogisticSaturation
        from pymc_marketing.special_priors import MaskedPrior

        coords = {
            "country": ["Colombia", "Venezuela"],
            "channel": ["x1", "x2", "x3", "x4"],
        }
        mask_excluded_x4_colombia = xr.DataArray(
            [[True, False, True, False], [True, True, True, True]],
            dims=["country", "channel"],
            coords=coords,
        )
        saturation = LogisticSaturation(
            priors={
                "lam": MaskedPrior(
                    Prior(
                        "Gamma",
                        mu=2,
                        sigma=0.5,
                        dims=("country", "channel"),
                    ),
                    mask=mask_excluded_x4_colombia,
                ),
                "beta": Prior(
                    "Gamma",
                    mu=3,
                    sigma=0.5,
                    dims=("country", "channel"),
                ),
            }
        )
        prior = saturation.sample_prior(coords=coords, random_seed=10)
        curve = saturation.sample_curve(prior)
        saturation.plot_curve(
            curve,
            subplot_kwargs={
                "ncols": 4,
                "figsize": (12, 18),
            },
        )
Masked likelihood over an arbitrary subset of entries (2D example over (date, country)):
.. code-block:: python
import numpy as np
import xarray as xr
import pymc as pm
from pymc_extras.prior import Prior
from pymc_marketing.special_priors import MaskedPrior
coords = {
"date": np.array(["2021-01-01", "2021-01-02"], dtype="datetime64[D]"),
"country": ["Venezuela", "Colombia"],
}
mask = xr.DataArray(
[[True, False], [True, False]],
dims=["date", "country"],
coords={"date": coords["date"], "country": coords["country"]},
)
intercept = Prior("Normal", mu=0, sigma=10, dims=("country",))
likelihood = Prior(
"Normal", sigma=Prior("HalfNormal", sigma=1), dims=("date", "country")
)
observed = np.random.normal(0, 1, size=(2, 2))
with pm.Model(coords=coords):
mu = intercept.create_variable("intercept")
masked = MaskedPrior(likelihood, mask)
y = masked.create_likelihood_variable("y", mu=mu, observed=observed)
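
    Serialization round-trip through the registered deserializer (a minimal
    sketch reusing ``masked`` from the examples above; note that ``to_dict``
    stores only the mask values and dims, not its coordinates):

    .. code-block:: python

        from pymc_extras.deserialize import deserialize

        data = masked.to_dict()
        masked_again = deserialize(data)
        assert isinstance(masked_again, MaskedPrior)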
"""
def __init__(
self, prior: Prior, mask: xr.DataArray, active_dim: str | None = None
) -> None:
self.prior = prior
self.mask = mask
self.dims = prior.dims
self.active_dim = active_dim or f"non_null_dims:{'_'.join(self.dims)}"
self._validate_mask()
warnings.warn(
"This class is experimental and its API may change in future versions.",
stacklevel=2,
)
def _validate_mask(self) -> None:
if tuple(self.mask.dims) != tuple(self.dims):
raise ValueError("mask dims must match prior.dims order")
def _remap_dims(self, factory: VariableFactory) -> VariableFactory:
# Depth-first remap of any nested VariableFactory with dims == parent dims
# This keeps internal subset checks (_param_dims_work) satisfied.
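        # Example: with mask dims ("country", "channel"), both the prior and
        # any nested parameter priors over exactly those dims get dims
        # ("non_null_dims:country_channel",)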
if hasattr(factory, "parameters"):
# Recurse on child parameters first
for key, value in list(factory.parameters.items()):
if hasattr(value, "create_variable") and hasattr(value, "dims"):
factory.parameters[key] = self._remap_dims(value) # type: ignore[arg-type]
# Now remap this object's dims if they exactly match the masked dims
if hasattr(factory, "dims"):
dims = factory.dims
if isinstance(dims, str):
dims = (dims,)
if tuple(dims) == tuple(self.dims):
factory.dims = (self.active_dim,)
return factory
def create_variable(self, name: str) -> TensorVariable:
"""Create a deterministic variable with full dims using the active subset.
Creates an underlying variable over the active entries only and expands
it back to the full masked shape, filling inactive entries with zeros.
Parameters
----------
name : str
Base name for the created variables.
Returns
-------
pt.TensorVariable
Deterministic variable with the original dims, zeros on inactive entries.
"""
model = pm.modelcontext(None)
flat_mask = self.mask.values.ravel().astype(bool)
n_active = int(flat_mask.sum())
if n_active == 0:
return pm.Deterministic(name, pt.zeros(self.mask.shape), dims=self.dims)
        # If a coord with this name already exists with a different length,
        # switch to a unique name, then register the active coordinate
if (
self.active_dim in model.coords
and len(model.coords[self.active_dim]) != n_active
):
self.active_dim = f"{self.active_dim}__{n_active}"
model.add_coords({self.active_dim: np.arange(n_active)})
# Make a deep copy and remap dims depth-first before creating the RV
reduced = self._remap_dims(self.prior.deepcopy())
active_rv = reduced.create_variable(f"{name}_active") # shape: (active_dim,)
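        # Scatter the active values into a flat vector of zeros, then reshape
        # back to the full masked shape; inactive entries remain zero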
flat_full = pt.zeros((self.mask.size,), dtype=active_rv.dtype)
full = flat_full[flat_mask].set(active_rv).reshape(self.mask.shape)
return pm.Deterministic(name, full, dims=self.dims)
def to_dict(self) -> dict[str, Any]:
"""Serialize MaskedPrior to a JSON-serializable dictionary.
Returns
-------
dict
Dictionary containing the prior, mask, and active_dim.
"""
        # Store the mask as a plain nested list of bools so coordinate values
        # (e.g. datetimes) never need to be serialized
mask_list = (
self.mask.values.astype(bool).tolist()
if hasattr(self.mask, "values")
else np.asarray(self.mask, dtype=bool).tolist()
)
return {
"class": "MaskedPrior",
"data": {
"prior": self.prior.to_dict()
if hasattr(self.prior, "to_dict")
else None,
"mask": mask_list,
"mask_dims": list(self.dims),
"active_dim": self.active_dim,
},
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "MaskedPrior":
"""Deserialize MaskedPrior from dictionary created by ``to_dict``.
Parameters
----------
data : dict
Dictionary produced by :meth:`to_dict`.
Returns
-------
MaskedPrior
Reconstructed instance.
"""
payload = data["data"] if "data" in data else data
prior = (
deserialize(payload["prior"])
if isinstance(payload.get("prior"), dict)
else payload.get("prior")
)
mask_vals = payload.get("mask")
# Fallback to provided dims or infer from prior if available
mask_dims = payload.get("mask_dims") or (getattr(prior, "dims", None) or ())
mask_da = xr.DataArray(np.asarray(mask_vals, dtype=bool), dims=tuple(mask_dims))
active_dim = payload.get("active_dim")
return cls(prior=prior, mask=mask_da, active_dim=active_dim)
def create_likelihood_variable(
self, name: str, *, mu: pt.TensorLike, observed: pt.TensorLike
) -> TensorVariable:
"""Create an observed variable over the active subset and expand to full dims.
Parameters
----------
name : str
Base name for the created variables.
mu : pt.TensorLike
Mean/location parameter broadcastable to the masked shape.
observed : pt.TensorLike
Observations broadcastable to the masked shape.
Returns
-------
pt.TensorVariable
Deterministic variable over the full dims with observed RV on active entries.
"""
model = pm.modelcontext(None)
flat_mask = self.mask.values.ravel().astype(bool)
n_active = int(flat_mask.sum())
if n_active == 0:
return pm.Deterministic(name, pt.zeros(self.mask.shape), dims=self.dims)
        # If a coord with this name already exists with a different length,
        # switch to a unique name, then register the active coordinate
if (
self.active_dim in model.coords
and len(model.coords[self.active_dim]) != n_active
):
self.active_dim = f"{self.active_dim}__{n_active}"
model.add_coords({self.active_dim: np.arange(n_active)})
# Remap dims on a deep copy so nested parameter priors match the active subset
reduced = self._remap_dims(self.prior.deepcopy())
# Broadcast mu/observed to full mask shape via arithmetic broadcasting, then select active entries
mu_tensor = pt.as_tensor_variable(mu)
mu_full = mu_tensor + pt.zeros(self.mask.shape, dtype=mu_tensor.dtype)
mu_active = mu_full.reshape((self.mask.size,))[flat_mask]
obs = observed.values if hasattr(observed, "values") else observed
obs_tensor = pt.as_tensor_variable(obs)
obs_full = obs_tensor + pt.zeros(self.mask.shape, dtype=obs_tensor.dtype)
obs_active = obs_full.reshape((self.mask.size,))[flat_mask]
# Create the masked observed RV over the active subset
active_name = f"{name}_active"
active_rv = reduced.create_likelihood_variable(
active_name, mu=mu_active, observed=obs_active
)
# Expand back to full shape for user-friendly access
flat_full = pt.zeros((self.mask.size,), dtype=active_rv.dtype)
full = flat_full[flat_mask].set(active_rv).reshape(self.mask.shape)
return pm.Deterministic(name, full, dims=self.dims)
def _is_masked_prior_type(data: dict) -> bool:
return data.keys() == {"class", "data"} and data.get("class") == "MaskedPrior"
register_deserialization(
is_type=_is_masked_prior_type, deserialize=MaskedPrior.from_dict
)