pytransc.utils.types

Custom types for pytransc.

  1"""Custom types for pytransc."""
  2
  3from typing import Annotated, Protocol, TypeAlias
  4
  5import numpy as np
  6import numpy.typing as npt
  7
  8# See https://medium.com/data-science-collective/do-more-with-numpy-array-type-hints-annotate-validate-shape-dtype-09f81c496746
  9# for guidance on numpy type annotations.
 10# These types are not actually supported by type checkers, so this is more for documentation purposes.
 11# Current numpy type annotations only specify the dtype, not the shape.
 12IntArray: TypeAlias = npt.NDArray[np.integer]
 13FloatArray: TypeAlias = npt.NDArray[np.floating]
 14MultiWalkerStateChain: TypeAlias = Annotated[IntArray, "(n_walkers, n_steps)"]
 15MultiWalkerModelChain: TypeAlias = Annotated[
 16    list[list[FloatArray]],
 17    "(n_walkers, n_steps, n_dims[state_i])",
 18]
 19StateOrderedEnsemble: TypeAlias = list[FloatArray]
 20
 21
 22class MultiStateDensity(Protocol):
 23    """Protocol for multi-state density functions.
 24
 25    This protocol defines the interface for functions that can evaluate
 26    log-density values at points in different states. These functions
 27    are not necessarily normalized.
 28
 29    Used by all trans-conceptual samplers for posterior evaluation.
 30    """
 31
 32    def __call__(self, x: FloatArray, state: int) -> float:
 33        """Evaluate the log-density at point x in the given state.
 34
 35        Parameters
 36        ----------
 37        x : FloatArray
 38            Input point where the density is evaluated. Shape should match
 39            the parameter space dimension for the given state.
 40        state : int
 41            The state index (0-based) for which the density is evaluated.
 42
 43        Returns
 44        -------
 45        float
 46            Log-density value at x in the specified state.
 47        """
 48        ...
 49
 50
 51class MultiStateDraw(Protocol):
 52    """Protocol for multi-state density functions that support drawing."""
 53
 54    def __call__(self, state: int) -> FloatArray:
 55        """Draw a random sample from the distribution for the given state.
 56
 57        Parameters
 58        ----------
 59        state : int
 60            The state index from which to draw the sample.
 61
 62        Returns
 63        -------
 64        FloatArray
 65            A random sample from the distribution in the specified state.
 66            Shape should match the parameter space dimension for that state.
 67        """
 68        ...
 69
 70
 71class SampleableMultiStateDensity(Protocol):
 72    """Protocol for multi-state density functions that support sampling.
 73
 74    This protocol extends MultiStateDensity to include the ability to
 75    generate random samples from the distribution. This is primarily
 76    used by the state-jump sampler for drawing from pseudo-priors
 77    when proposing between-state moves.
 78    """
 79
 80    __call__: MultiStateDensity
 81    draw_deviate: MultiStateDraw
 82
 83
 84class ProposableMultiStateDensity(MultiStateDensity, Protocol):
 85    """Protocol for multi-state density functions that support proposals.
 86
 87    This protocol extends MultiStateDensity to include the ability to
 88    generate proposal moves within a state. Used by the state-jump
 89    sampler for within-state moves.
 90    """
 91
 92    def propose(self, x: FloatArray, state: int) -> FloatArray:
 93        """Propose a new point based on the current point x.
 94
 95        Parameters
 96        ----------
 97        x : FloatArray
 98            Current point in the parameter space for the given state.
 99        state : int
100            The state index for which the proposal is generated.
101
102        Returns
103        -------
104        FloatArray
105            Proposed new point in the same state. Shape should match
106            the input parameter x.
107        """
108        ...
109
110
class MultiStateMultiWalkerResult(Protocol):
    """Protocol for results from multi-state multi-walker samplers.

    This protocol defines the interface for objects that store the results
    of trans-conceptual MCMC sampling with multiple walkers. It provides
    read-only properties for the state chains, the model chains and visit
    statistics.
    """

    @property
    def state_chain(self) -> MultiWalkerStateChain:
        """State visitation sequence for each walker.

        Returns
        -------
        MultiWalkerStateChain
            Integer array of shape (n_walkers, n_steps) containing the state
            index visited by each walker at each step.
        """
        ...

    @property
    def model_chain(self) -> MultiWalkerModelChain:
        """Model (parameter-vector) chain for each walker.

        Returns
        -------
        MultiWalkerModelChain
            Nested lists indexed as ``model_chain[walker][step]``, with layout
            (n_walkers, n_steps, n_dims[state_i]). Each entry is a float array
            whose length is the dimension of the state that walker occupied at
            that step. Because that dimension can differ between states, the
            structure is ragged rather than a single 3-D array.
        """
        ...

    @property
    def state_chain_tot(self) -> IntArray:
        """Running totals of states visited along the Markov chains.

        Returns
        -------
        IntArray
            Integer array of visit counts. This protocol does not fix its
            layout (presumably cumulative counts per state along the chain;
            confirm against the implementing class).
        """
        ...

    @property
    def n_walkers(self) -> int:
        """Number of walkers in the sampler."""
        ...

    @property
    def n_steps(self) -> int:
        """Number of steps in the sampler."""
        ...

    @property
    def n_states(self) -> int:
        """Number of states in the sampler.

        Presumably the valid state indices are ``0 .. n_states - 1``; confirm
        against the implementing class.
        """
        ...
# Rendered signatures of the module's type aliases: npt.NDArray is shown
# expanded to numpy.ndarray[tuple[Any, ...], numpy.dtype[...]]. The shape
# strings inside Annotated are documentation-only metadata.
IntArray: TypeAlias = numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.integer]]
FloatArray: TypeAlias = numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]
MultiWalkerStateChain: TypeAlias = Annotated[numpy.ndarray[tuple[Any, ...], numpy.dtype[numpy.integer]], '(n_walkers, n_steps)']
MultiWalkerModelChain: TypeAlias = Annotated[list[list[numpy.ndarray[tuple[Any, ...], numpy.dtype[numpy.floating]]]], '(n_walkers, n_steps, n_dims[state_i])']
StateOrderedEnsemble: TypeAlias = list[numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]]
# Rendered source listing of MultiStateDensity; the leading numbers on the
# listing lines are listing line numbers, not part of the code.
class MultiStateDensity(typing.Protocol):
23class MultiStateDensity(Protocol):
24    """Protocol for multi-state density functions.
25
26    This protocol defines the interface for functions that can evaluate
27    log-density values at points in different states. These functions
28    are not necessarily normalized.
29
30    Used by all trans-conceptual samplers for posterior evaluation.
31    """
32
33    def __call__(self, x: FloatArray, state: int) -> float:
34        """Evaluate the log-density at point x in the given state.
35
36        Parameters
37        ----------
38        x : FloatArray
39            Input point where the density is evaluated. Shape should match
40            the parameter space dimension for the given state.
41        state : int
42            The state index (0-based) for which the density is evaluated.
43
44        Returns
45        -------
46        float
47            Log-density value at x in the specified state.
48        """
49        ...

Protocol for multi-state density functions.

This protocol defines the interface for functions that can evaluate log-density values at points in different states. These functions are not necessarily normalized.

Used by all trans-conceptual samplers for posterior evaluation.

# Rendered constructor signature. Protocol classes use the standard-library
# typing helper _no_init_or_replace_init (listed below) as __init__, so calling
# MultiStateDensity(...) directly raises TypeError('Protocols cannot be instantiated').
MultiStateDensity(*args, **kwargs)
1739def _no_init_or_replace_init(self, *args, **kwargs):
1740    cls = type(self)
1741
1742    if cls._is_protocol:
1743        raise TypeError('Protocols cannot be instantiated')
1744
1745    # Already using a custom `__init__`. No need to calculate correct
1746    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1747    if cls.__init__ is not _no_init_or_replace_init:
1748        return
1749
1750    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1751    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1752    # searches for a proper new `__init__` in the MRO. The new `__init__`
1753    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1754    # instantiation of the protocol subclass will thus use the new
1755    # `__init__` and no longer call `_no_init_or_replace_init`.
1756    for base in cls.__mro__:
1757        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1758        if init is not _no_init_or_replace_init:
1759            cls.__init__ = init
1760            break
1761    else:
1762        # should not happen
1763        cls.__init__ = object.__init__
1764
1765    cls.__init__(self, *args, **kwargs)
class MultiStateDraw(typing.Protocol):
52class MultiStateDraw(Protocol):
53    """Protocol for callables that draw a random sample from a given state."""
54
55    def __call__(self, state: int) -> FloatArray:
56        """Draw a random sample from the distribution for the given state.
57
58        Parameters
59        ----------
60        state : int
61            The state index from which to draw the sample.
62
63        Returns
64        -------
65        FloatArray
66            A random sample from the distribution in the specified state.
67            Shape should match the parameter space dimension for that state.
68        """
69        ...

Protocol for callables that draw a random sample from a given state.

# Rendered constructor signature. Protocol classes use the standard-library
# typing helper _no_init_or_replace_init (listed below) as __init__, so calling
# MultiStateDraw(...) directly raises TypeError('Protocols cannot be instantiated').
MultiStateDraw(*args, **kwargs)
1739def _no_init_or_replace_init(self, *args, **kwargs):
1740    cls = type(self)
1741
1742    if cls._is_protocol:
1743        raise TypeError('Protocols cannot be instantiated')
1744
1745    # Already using a custom `__init__`. No need to calculate correct
1746    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1747    if cls.__init__ is not _no_init_or_replace_init:
1748        return
1749
1750    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1751    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1752    # searches for a proper new `__init__` in the MRO. The new `__init__`
1753    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1754    # instantiation of the protocol subclass will thus use the new
1755    # `__init__` and no longer call `_no_init_or_replace_init`.
1756    for base in cls.__mro__:
1757        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1758        if init is not _no_init_or_replace_init:
1759            cls.__init__ = init
1760            break
1761    else:
1762        # should not happen
1763        cls.__init__ = object.__init__
1764
1765    cls.__init__(self, *args, **kwargs)
# NOTE(review): the docstring below says this protocol "extends
# MultiStateDensity", but the listed class inherits only from Protocol. It
# instead declares __call__ as an attribute of type MultiStateDensity,
# alongside a draw_deviate attribute of type MultiStateDraw.
class SampleableMultiStateDensity(typing.Protocol):
72class SampleableMultiStateDensity(Protocol):
73    """Protocol for multi-state density functions that support sampling.
74
75    This protocol extends MultiStateDensity to include the ability to
76    generate random samples from the distribution. This is primarily
77    used by the state-jump sampler for drawing from pseudo-priors
78    when proposing between-state moves.
79    """
80
81    __call__: MultiStateDensity
82    draw_deviate: MultiStateDraw

Protocol for multi-state density functions that support sampling.

This protocol extends MultiStateDensity to include the ability to generate random samples from the distribution. This is primarily used by the state-jump sampler for drawing from pseudo-priors when proposing between-state moves.

# Rendered constructor signature. Protocol classes use the standard-library
# typing helper _no_init_or_replace_init (listed below) as __init__, so calling
# SampleableMultiStateDensity(...) directly raises TypeError('Protocols cannot be instantiated').
SampleableMultiStateDensity(*args, **kwargs)
1739def _no_init_or_replace_init(self, *args, **kwargs):
1740    cls = type(self)
1741
1742    if cls._is_protocol:
1743        raise TypeError('Protocols cannot be instantiated')
1744
1745    # Already using a custom `__init__`. No need to calculate correct
1746    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1747    if cls.__init__ is not _no_init_or_replace_init:
1748        return
1749
1750    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1751    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1752    # searches for a proper new `__init__` in the MRO. The new `__init__`
1753    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1754    # instantiation of the protocol subclass will thus use the new
1755    # `__init__` and no longer call `_no_init_or_replace_init`.
1756    for base in cls.__mro__:
1757        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1758        if init is not _no_init_or_replace_init:
1759            cls.__init__ = init
1760            break
1761    else:
1762        # should not happen
1763        cls.__init__ = object.__init__
1764
1765    cls.__init__(self, *args, **kwargs)
# Rendered attribute member of SampleableMultiStateDensity: a callable that
# draws a sample for a given state.
draw_deviate: MultiStateDraw
# Rendered source listing of ProposableMultiStateDensity. It inherits
# MultiStateDensity, so implementers provide both __call__(x, state) and
# propose(x, state).
class ProposableMultiStateDensity(MultiStateDensity, typing.Protocol):
 85class ProposableMultiStateDensity(MultiStateDensity, Protocol):
 86    """Protocol for multi-state density functions that support proposals.
 87
 88    This protocol extends MultiStateDensity to include the ability to
 89    generate proposal moves within a state. Used by the state-jump
 90    sampler for within-state moves.
 91    """
 92
 93    def propose(self, x: FloatArray, state: int) -> FloatArray:
 94        """Propose a new point based on the current point x.
 95
 96        Parameters
 97        ----------
 98        x : FloatArray
 99            Current point in the parameter space for the given state.
100        state : int
101            The state index for which the proposal is generated.
102
103        Returns
104        -------
105        FloatArray
106            Proposed new point in the same state. Shape should match
107            the input parameter x.
108        """
109        ...

Protocol for multi-state density functions that support proposals.

This protocol extends MultiStateDensity to include the ability to generate proposal moves within a state. Used by the state-jump sampler for within-state moves.

# Rendered signature of propose with FloatArray expanded; the proposed point is
# expected to have the same shape as the input x.
def propose( self, x: numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]], state: int) -> numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]:
 93    def propose(self, x: FloatArray, state: int) -> FloatArray:
 94        """Propose a new point based on the current point x.
 95
 96        Parameters
 97        ----------
 98        x : FloatArray
 99            Current point in the parameter space for the given state.
100        state : int
101            The state index for which the proposal is generated.
102
103        Returns
104        -------
105        FloatArray
106            Proposed new point in the same state. Shape should match
107            the input parameter x.
108        """
109        ...

Propose a new point based on the current point x.

Parameters

x : FloatArray
    Current point in the parameter space for the given state.
state : int
    The state index for which the proposal is generated.

Returns

FloatArray
    Proposed new point in the same state. Shape should match the input parameter x.

# Rendered source listing of MultiStateMultiWalkerResult; every member is a
# read-only @property.
class MultiStateMultiWalkerResult(typing.Protocol):
112class MultiStateMultiWalkerResult(Protocol):
113    """Protocol for results from multi-state multi-walker samplers.
114
115    This protocol defines the interface for objects that store the results
116    of trans-conceptual MCMC sampling with multiple walkers. It provides
117    access to state chains, model chains, and visit statistics.
118    """
119
120    @property
121    def state_chain(self) -> MultiWalkerStateChain:
122        """State visitation sequence for each walker.
123
124        Returns
125        -------
126        MultiWalkerStateChain
127            Array of shape (n_walkers, n_steps) containing the state
128            index visited by each walker at each step.
129
130        Expected shape is (n_walkers, n_steps).
131        """
132        ...
133
134    @property
135    def model_chain(self) -> MultiWalkerModelChain:
136        """Model chain for each walker.
137
138        Expected shape is (n_walkers, n_steps, n_dims[state_i]).
139        """
140        ...
141
142    @property
143    def state_chain_tot(self) -> IntArray:
144        """Running totals of states visited along the Markov chains."""
145        ...
146
147    @property
148    def n_walkers(self) -> int:
149        """Number of walkers in the sampler."""
150        ...
151
152    @property
153    def n_steps(self) -> int:
154        """Number of steps in the sampler."""
155        ...
156
157    @property
158    def n_states(self) -> int:
159        """Number of states in the sampler."""
160        ...

Protocol for results from multi-state multi-walker samplers.

This protocol defines the interface for objects that store the results of trans-conceptual MCMC sampling with multiple walkers. It provides access to state chains, model chains, and visit statistics.

# Rendered constructor signature. Protocol classes use the standard-library
# typing helper _no_init_or_replace_init (listed below) as __init__, so calling
# MultiStateMultiWalkerResult(...) directly raises TypeError('Protocols cannot be instantiated').
MultiStateMultiWalkerResult(*args, **kwargs)
1739def _no_init_or_replace_init(self, *args, **kwargs):
1740    cls = type(self)
1741
1742    if cls._is_protocol:
1743        raise TypeError('Protocols cannot be instantiated')
1744
1745    # Already using a custom `__init__`. No need to calculate correct
1746    # `__init__` to call. This can lead to RecursionError. See bpo-45121.
1747    if cls.__init__ is not _no_init_or_replace_init:
1748        return
1749
1750    # Initially, `__init__` of a protocol subclass is set to `_no_init_or_replace_init`.
1751    # The first instantiation of the subclass will call `_no_init_or_replace_init` which
1752    # searches for a proper new `__init__` in the MRO. The new `__init__`
1753    # replaces the subclass' old `__init__` (ie `_no_init_or_replace_init`). Subsequent
1754    # instantiation of the protocol subclass will thus use the new
1755    # `__init__` and no longer call `_no_init_or_replace_init`.
1756    for base in cls.__mro__:
1757        init = base.__dict__.get('__init__', _no_init_or_replace_init)
1758        if init is not _no_init_or_replace_init:
1759            cls.__init__ = init
1760            break
1761    else:
1762        # should not happen
1763        cls.__init__ = object.__init__
1764
1765    cls.__init__(self, *args, **kwargs)
# Rendered read-only property; the '(n_walkers, n_steps)' string is
# documentation-only shape metadata carried by Annotated.
state_chain: Annotated[numpy.ndarray[tuple[Any, ...], numpy.dtype[numpy.integer]], '(n_walkers, n_steps)']
120    @property
121    def state_chain(self) -> MultiWalkerStateChain:
122        """State visitation sequence for each walker.
123
124        Returns
125        -------
126        MultiWalkerStateChain
127            Array of shape (n_walkers, n_steps) containing the state
128            index visited by each walker at each step.
129
130        Expected shape is (n_walkers, n_steps).
131        """
132        ...

State visitation sequence for each walker.

Returns

MultiWalkerStateChain
    Array of shape (n_walkers, n_steps) containing the state index visited by each walker at each step.

Expected shape is (n_walkers, n_steps).

# Rendered read-only property; nested lists because the innermost dimension
# n_dims[state_i] depends on the state visited at each step.
model_chain: Annotated[list[list[numpy.ndarray[tuple[Any, ...], numpy.dtype[numpy.floating]]]], '(n_walkers, n_steps, n_dims[state_i])']
134    @property
135    def model_chain(self) -> MultiWalkerModelChain:
136        """Model chain for each walker.
137
138        Expected shape is (n_walkers, n_steps, n_dims[state_i]).
139        """
140        ...

Model chain for each walker.

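Expected shape is (n_walkers, n_steps, n_dims[state_i]): the length of each innermost array is the dimension of the state visited at that step, so the chain is ragged rather than a single 3-D array.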

# Remaining rendered members of MultiStateMultiWalkerResult; each is a
# read-only @property.
state_chain_tot: numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.integer]]
142    @property
143    def state_chain_tot(self) -> IntArray:
144        """Running totals of states visited along the Markov chains."""
145        ...

Running totals of states visited along the Markov chains.

n_walkers: int
147    @property
148    def n_walkers(self) -> int:
149        """Number of walkers in the sampler."""
150        ...

Number of walkers in the sampler.

n_steps: int
152    @property
153    def n_steps(self) -> int:
154        """Number of steps in the sampler."""
155        ...

Number of steps in the sampler.

n_states: int
157    @property
158    def n_states(self) -> int:
159        """Number of states in the sampler."""
160        ...

Number of states in the sampler.