pytransc.utils.types
Custom types for pytransc.
1"""Custom types for pytransc.""" 2 3from typing import Annotated, Protocol, TypeAlias 4 5import numpy as np 6import numpy.typing as npt 7 8# See https://medium.com/data-science-collective/do-more-with-numpy-array-type-hints-annotate-validate-shape-dtype-09f81c496746 9# for guidance on numpy type annotations. 10# These types are not actually supported by type checkers, so this is more for documentation purposes. 11# Current numpy type annotations only specify the dtype, not the shape. 12IntArray: TypeAlias = npt.NDArray[np.integer] 13FloatArray: TypeAlias = npt.NDArray[np.floating] 14MultiWalkerStateChain: TypeAlias = Annotated[IntArray, "(n_walkers, n_steps)"] 15MultiWalkerModelChain: TypeAlias = Annotated[ 16 list[list[FloatArray]], 17 "(n_walkers, n_steps, n_dims[state_i])", 18] 19StateOrderedEnsemble: TypeAlias = list[FloatArray] 20 21 22class MultiStateDensity(Protocol): 23 """Protocol for multi-state density functions. 24 25 This protocol defines the interface for functions that can evaluate 26 log-density values at points in different states. These functions 27 are not necessarily normalized. 28 29 Used by all trans-conceptual samplers for posterior evaluation. 30 """ 31 32 def __call__(self, x: FloatArray, state: int) -> float: 33 """Evaluate the log-density at point x in the given state. 34 35 Parameters 36 ---------- 37 x : FloatArray 38 Input point where the density is evaluated. Shape should match 39 the parameter space dimension for the given state. 40 state : int 41 The state index (0-based) for which the density is evaluated. 42 43 Returns 44 ------- 45 float 46 Log-density value at x in the specified state. 47 """ 48 ... 49 50 51class MultiStateDraw(Protocol): 52 """Protocol for multi-state density functions that support drawing.""" 53 54 def __call__(self, state: int) -> FloatArray: 55 """Draw a random sample from the distribution for the given state. 56 57 Parameters 58 ---------- 59 state : int 60 The state index from which to draw the sample. 
61 62 Returns 63 ------- 64 FloatArray 65 A random sample from the distribution in the specified state. 66 Shape should match the parameter space dimension for that state. 67 """ 68 ... 69 70 71class SampleableMultiStateDensity(Protocol): 72 """Protocol for multi-state density functions that support sampling. 73 74 This protocol extends MultiStateDensity to include the ability to 75 generate random samples from the distribution. This is primarily 76 used by the state-jump sampler for drawing from pseudo-priors 77 when proposing between-state moves. 78 """ 79 80 __call__: MultiStateDensity 81 draw_deviate: MultiStateDraw 82 83 84class ProposableMultiStateDensity(MultiStateDensity, Protocol): 85 """Protocol for multi-state density functions that support proposals. 86 87 This protocol extends MultiStateDensity to include the ability to 88 generate proposal moves within a state. Used by the state-jump 89 sampler for within-state moves. 90 """ 91 92 def propose(self, x: FloatArray, state: int) -> FloatArray: 93 """Propose a new point based on the current point x. 94 95 Parameters 96 ---------- 97 x : FloatArray 98 Current point in the parameter space for the given state. 99 state : int 100 The state index for which the proposal is generated. 101 102 Returns 103 ------- 104 FloatArray 105 Proposed new point in the same state. Shape should match 106 the input parameter x. 107 """ 108 ... 109 110 111class MultiStateMultiWalkerResult(Protocol): 112 """Protocol for results from multi-state multi-walker samplers. 113 114 This protocol defines the interface for objects that store the results 115 of trans-conceptual MCMC sampling with multiple walkers. It provides 116 access to state chains, model chains, and visit statistics. 117 """ 118 119 @property 120 def state_chain(self) -> MultiWalkerStateChain: 121 """State visitation sequence for each walker. 
122 123 Returns 124 ------- 125 MultiWalkerStateChain 126 Array of shape (n_walkers, n_steps) containing the state 127 index visited by each walker at each step. 128 129 Expected shape is (n_walkers, n_steps). 130 """ 131 ... 132 133 @property 134 def model_chain(self) -> MultiWalkerModelChain: 135 """Model chain for each walker. 136 137 Expected shape is (n_walkers, n_steps, n_dims). 138 """ 139 ... 140 141 @property 142 def state_chain_tot(self) -> IntArray: 143 """Running totals of states visited along the Markov chains.""" 144 ... 145 146 @property 147 def n_walkers(self) -> int: 148 """Number of walkers in the sampler.""" 149 ... 150 151 @property 152 def n_steps(self) -> int: 153 """Number of steps in the sampler.""" 154 ... 155 156 @property 157 def n_states(self) -> int: 158 """Number of states in the sampler.""" 159 ...
class MultiStateDensity(Protocol):
    """Protocol for callables that return (possibly unnormalized) log-densities.

    An implementation evaluates a log-density at a point ``x`` that lives in
    one of several states. The density need not be normalized. Every
    trans-conceptual sampler relies on this interface to evaluate the
    posterior.
    """

    def __call__(self, x: FloatArray, state: int) -> float:
        """Return the log-density of ``x`` within ``state``.

        Parameters
        ----------
        x : FloatArray
            Point at which to evaluate. Its shape should match the
            parameter-space dimension of ``state``.
        state : int
            Zero-based index of the state to evaluate in.

        Returns
        -------
        float
            The log-density value of ``x`` in ``state``.
        """
        ...
Protocol for multi-state density functions.
This protocol defines the interface for functions that can evaluate log-density values at points in different states. These functions are not necessarily normalized.
Used by all trans-conceptual samplers for posterior evaluation.
def _no_init_or_replace_init(self, *args, **kwargs):
    # Placeholder ``__init__`` installed on Protocol classes. It refuses to
    # instantiate a protocol itself. For a concrete subclass it installs the
    # first real ``__init__`` found along the MRO and delegates to it.
    owner = type(self)

    if owner._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # A custom ``__init__`` is already in place, so there is nothing to
    # resolve. Continuing here can cause a RecursionError (bpo-45121).
    if owner.__init__ is not _no_init_or_replace_init:
        return

    # Lazily walk the MRO and take the first ``__init__`` that is not this
    # placeholder. ``object.__init__`` is only a safety fallback.
    # Caching the result on the class means later instantiations skip this
    # function entirely.
    candidates = (
        klass.__dict__.get('__init__', _no_init_or_replace_init)
        for klass in owner.__mro__
    )
    owner.__init__ = next(
        (init for init in candidates if init is not _no_init_or_replace_init),
        object.__init__,
    )

    owner.__init__(self, *args, **kwargs)
class MultiStateDraw(Protocol):
    """Protocol for functions that draw random samples from a multi-state distribution.

    Unlike MultiStateDensity, implementations do not evaluate a density.
    They return a random point in the parameter space of the requested
    state.
    """

    def __call__(self, state: int) -> FloatArray:
        """Draw a random sample from the distribution for the given state.

        Parameters
        ----------
        state : int
            The state index from which to draw the sample.

        Returns
        -------
        FloatArray
            A random sample from the distribution in the specified state.
            Shape should match the parameter space dimension for that state.
        """
        ...
Protocol for functions that draw random samples from a multi-state distribution.
def _no_init_or_replace_init(self, *args, **kwargs):
    # Placeholder ``__init__`` installed on Protocol classes. It refuses to
    # instantiate a protocol itself. For a concrete subclass it installs the
    # first real ``__init__`` found along the MRO and delegates to it.
    owner = type(self)

    if owner._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # A custom ``__init__`` is already in place, so there is nothing to
    # resolve. Continuing here can cause a RecursionError (bpo-45121).
    if owner.__init__ is not _no_init_or_replace_init:
        return

    # Lazily walk the MRO and take the first ``__init__`` that is not this
    # placeholder. ``object.__init__`` is only a safety fallback.
    # Caching the result on the class means later instantiations skip this
    # function entirely.
    candidates = (
        klass.__dict__.get('__init__', _no_init_or_replace_init)
        for klass in owner.__mro__
    )
    owner.__init__ = next(
        (init for init in candidates if init is not _no_init_or_replace_init),
        object.__init__,
    )

    owner.__init__(self, *args, **kwargs)
class SampleableMultiStateDensity(MultiStateDensity, Protocol):
    """Protocol for multi-state density functions that support sampling.

    This protocol extends MultiStateDensity to include the ability to
    generate random samples from the distribution. This is primarily
    used by the state-jump sampler for drawing from pseudo-priors
    when proposing between-state moves.
    """

    # ``__call__(x, state) -> float`` is inherited from MultiStateDensity.
    # Members are declared as methods rather than callable-typed attributes:
    # Python looks ``__call__`` up on the type, not the instance, and this
    # mirrors how ProposableMultiStateDensity declares ``propose``.

    def draw_deviate(self, state: int) -> FloatArray:
        """Draw a random sample from the distribution for the given state.

        Parameters
        ----------
        state : int
            The state index from which to draw the sample.

        Returns
        -------
        FloatArray
            A random sample from the distribution in the specified state.
            Shape should match the parameter space dimension for that state.
        """
        ...
Protocol for multi-state density functions that support sampling.
This protocol extends MultiStateDensity to include the ability to generate random samples from the distribution. This is primarily used by the state-jump sampler for drawing from pseudo-priors when proposing between-state moves.
def _no_init_or_replace_init(self, *args, **kwargs):
    # Placeholder ``__init__`` installed on Protocol classes. It refuses to
    # instantiate a protocol itself. For a concrete subclass it installs the
    # first real ``__init__`` found along the MRO and delegates to it.
    owner = type(self)

    if owner._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # A custom ``__init__`` is already in place, so there is nothing to
    # resolve. Continuing here can cause a RecursionError (bpo-45121).
    if owner.__init__ is not _no_init_or_replace_init:
        return

    # Lazily walk the MRO and take the first ``__init__`` that is not this
    # placeholder. ``object.__init__`` is only a safety fallback.
    # Caching the result on the class means later instantiations skip this
    # function entirely.
    candidates = (
        klass.__dict__.get('__init__', _no_init_or_replace_init)
        for klass in owner.__mro__
    )
    owner.__init__ = next(
        (init for init in candidates if init is not _no_init_or_replace_init),
        object.__init__,
    )

    owner.__init__(self, *args, **kwargs)
class ProposableMultiStateDensity(MultiStateDensity, Protocol):
    """Multi-state density that can also generate within-state proposals.

    On top of the MultiStateDensity call interface, implementations provide
    ``propose``. The state-jump sampler uses it for its within-state moves.
    """

    def propose(self, x: FloatArray, state: int) -> FloatArray:
        """Generate a within-state proposal starting from ``x``.

        Parameters
        ----------
        x : FloatArray
            The current point in the parameter space of ``state``.
        state : int
            Index of the state in which the proposal is made.

        Returns
        -------
        FloatArray
            The proposed point, in the same state and with the same shape
            as ``x``.
        """
        ...
Protocol for multi-state density functions that support proposals.
This protocol extends MultiStateDensity to include the ability to generate proposal moves within a state. Used by the state-jump sampler for within-state moves.
def propose(self, x: FloatArray, state: int) -> FloatArray:
    """Generate a within-state proposal starting from ``x``.

    Parameters
    ----------
    x : FloatArray
        The current point in the parameter space of ``state``.
    state : int
        Index of the state in which the proposal is made.

    Returns
    -------
    FloatArray
        The proposed point, in the same state and with the same shape
        as ``x``.
    """
    ...
Propose a new point based on the current point x.
Parameters
x : FloatArray Current point in the parameter space for the given state. state : int The state index for which the proposal is generated.
Returns
FloatArray Proposed new point in the same state. Shape should match the input parameter x.
Inherited Members
class MultiStateMultiWalkerResult(Protocol):
    """Protocol for results from multi-state multi-walker samplers.

    This protocol defines the interface for objects that store the results
    of trans-conceptual MCMC sampling with multiple walkers. It provides
    access to state chains, model chains, and visit statistics.
    """

    @property
    def state_chain(self) -> MultiWalkerStateChain:
        """State visitation sequence for each walker.

        Returns
        -------
        MultiWalkerStateChain
            Array of shape (n_walkers, n_steps) containing the state
            index visited by each walker at each step.
        """
        ...

    @property
    def model_chain(self) -> MultiWalkerModelChain:
        """Model chain for each walker.

        Returns
        -------
        MultiWalkerModelChain
            Nested lists indexed as ``[walker][step]``. Each entry is the
            model array at that step. Its length ``n_dims[state_i]``
            depends on the state visited, so the chain is ragged rather
            than a single (n_walkers, n_steps, n_dims) array.
        """
        ...

    @property
    def state_chain_tot(self) -> IntArray:
        """Running totals of states visited along the Markov chains."""
        ...

    @property
    def n_walkers(self) -> int:
        """Number of walkers in the sampler."""
        ...

    @property
    def n_steps(self) -> int:
        """Number of steps in the sampler."""
        ...

    @property
    def n_states(self) -> int:
        """Number of states in the sampler."""
        ...
Protocol for results from multi-state multi-walker samplers.
This protocol defines the interface for objects that store the results of trans-conceptual MCMC sampling with multiple walkers. It provides access to state chains, model chains, and visit statistics.
def _no_init_or_replace_init(self, *args, **kwargs):
    # Placeholder ``__init__`` installed on Protocol classes. It refuses to
    # instantiate a protocol itself. For a concrete subclass it installs the
    # first real ``__init__`` found along the MRO and delegates to it.
    owner = type(self)

    if owner._is_protocol:
        raise TypeError('Protocols cannot be instantiated')

    # A custom ``__init__`` is already in place, so there is nothing to
    # resolve. Continuing here can cause a RecursionError (bpo-45121).
    if owner.__init__ is not _no_init_or_replace_init:
        return

    # Lazily walk the MRO and take the first ``__init__`` that is not this
    # placeholder. ``object.__init__`` is only a safety fallback.
    # Caching the result on the class means later instantiations skip this
    # function entirely.
    candidates = (
        klass.__dict__.get('__init__', _no_init_or_replace_init)
        for klass in owner.__mro__
    )
    owner.__init__ = next(
        (init for init in candidates if init is not _no_init_or_replace_init),
        object.__init__,
    )

    owner.__init__(self, *args, **kwargs)
@property
def state_chain(self) -> MultiWalkerStateChain:
    """Per-walker sequence of visited states.

    Returns
    -------
    MultiWalkerStateChain
        Integer array of shape (n_walkers, n_steps). Entry ``[w, t]`` is
        the index of the state occupied by walker ``w`` at step ``t``.
    """
    ...
State visitation sequence for each walker.
Returns
MultiWalkerStateChain Array of shape (n_walkers, n_steps) containing the state index visited by each walker at each step.
Expected shape is (n_walkers, n_steps).
@property
def model_chain(self) -> MultiWalkerModelChain:
    """Model chain for each walker.

    Returns
    -------
    MultiWalkerModelChain
        Nested lists indexed as ``[walker][step]``. Each entry is the
        model array at that step. Its length ``n_dims[state_i]`` depends
        on the state visited, so the chain is ragged rather than a single
        (n_walkers, n_steps, n_dims) array.
    """
    ...
Model chain for each walker.
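Nested lists indexed as (n_walkers, n_steps). Each entry is an array of length n_dims[state_i], which depends on the state visited at that step, so the chain is ragged.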
@property
def state_chain_tot(self) -> IntArray:
    """Running totals of the states visited along the Markov chains.

    Returns
    -------
    IntArray
        Cumulative visit counts. The array layout is not specified by this
        protocol and is left to the implementing result class.
    """
    ...
Running totals of states visited along the Markov chains.