pytransc.samplers
Sampling algorithms for pyTransC.
This module provides the main sampling algorithms for trans-conceptual MCMC:
- Product space sampling: Fixed-dimensional sampling over the product space
- State-jump sampling: Direct jumping between different conceptual states
- Ensemble resampling: Resampling from pre-computed posterior ensembles
- Per-state MCMC: Independent sampling within each state
Each sampler has different advantages and use cases depending on the problem structure and computational requirements.
"""Sampling algorithms for pyTransC.

This module provides the main sampling algorithms for trans-conceptual MCMC:

- Product space sampling: Fixed-dimensional sampling over the product space
- State-jump sampling: Direct jumping between different conceptual states
- Ensemble resampling: Resampling from pre-computed posterior ensembles
- Per-state MCMC: Independent sampling within each state

Each sampler has different advantages and use cases depending on the problem
structure and computational requirements.
"""

from .ensemble_resampler import run_ensemble_resampler
from .per_state import run_mcmc_per_state
from .product_space import run_product_space_sampler
from .state_jump import run_state_jump_sampler

__all__ = [
    "run_product_space_sampler",
    "run_ensemble_resampler",
    "run_mcmc_per_state",
    "run_state_jump_sampler",
]
def run_product_space_sampler(
    product_space: ProductSpace,
    n_walkers: int,
    n_steps: int,
    start_positions: list[FloatArray],
    start_states: list[int],
    log_posterior: MultiStateDensity,
    log_pseudo_prior: MultiStateDensity,
    seed: int | None = 61254557,
    progress: bool = False,
    pool: Any | None = None,
    forward_pool: Any | None = None,
    **kwargs,
) -> MultiWalkerProductSpaceChain:
    """Run MCMC sampler over independent states using emcee in trans-C product space."""

    random.seed(seed)

    # Early validation of forward pool if provided
    if forward_pool is not None:
        from ..utils.forward_context import set_forward_pool, clear_forward_pool
        set_forward_pool(forward_pool)  # Validates map() method
        clear_forward_pool()  # Clear after validation

    if progress:
        print("\nRunning product space trans-C sampler")
        print("\nNumber of walkers               : ", n_walkers)
        print("Number of states being sampled  : ", product_space.n_states)
        print("Dimensions of each state        : ", product_space.n_dims)

    pos_ps = _get_initial_product_space_positions(
        n_walkers, start_states, start_positions, product_space
    )

    log_func = partial(
        product_space_log_prob,
        product_space=product_space,
        log_posterior=log_posterior,
        log_pseudo_prior=log_pseudo_prior,
        forward_pool=forward_pool,
    )

    sampler = perform_sampling_with_emcee(
        log_prob_func=log_func,
        n_walkers=n_walkers,
        n_steps=n_steps,
        initial_state=pos_ps,
        pool=pool,
        progress=progress,
        **kwargs,
    )

    return MultiWalkerProductSpaceChain.from_emcee(sampler, product_space.n_dims)
Run MCMC sampler over independent states using emcee in trans-C product space.
This function implements trans-conceptual MCMC sampling by embedding all states in a fixed-dimensional product space. The sampler uses the emcee ensemble sampler to explore the combined parameter space of all states.
Parameters
product_space : ProductSpace
The product space definition containing state dimensions.
n_walkers : int
Number of random walkers used by the product space sampler.
n_steps : int
Number of MCMC steps required per walker.
start_positions : list of FloatArray
Starting positions for walkers, one array per walker containing the
initial parameter values for the starting state.
start_states : list of int
Starting state indices for each walker.
log_posterior : MultiStateDensity
Function to evaluate the log-posterior density at location x in state i.
Must have signature log_posterior(x, state) -> float.
log_pseudo_prior : MultiStateDensity
Function to evaluate the log-pseudo-prior density at location x in state i.
Must have signature log_pseudo_prior(x, state) -> float.
Note: Must be normalized over respective state spaces.
seed : int, optional
Random number seed for reproducible results. Default is 61254557.
progress : bool, optional
Whether to display progress information. Default is False.
pool : Any | None, optional
User-provided pool for parallel processing. The pool must implement
a map() method compatible with the standard library's map() function.
Default is None.
forward_pool : Any | None, optional
User-provided pool for parallelizing forward solver calls within
log_posterior evaluations. If provided, the pool will be made available
to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
The pool must implement a map() method compatible with the standard library's
map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
and schwimmbad pools. Default is None.
**kwargs
Additional keyword arguments passed to the emcee sampler.
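The `log_posterior(x, state) -> float` and `log_pseudo_prior(x, state) -> float` signatures described above can be illustrated with a minimal sketch. The three-state Gaussian problem, the `N_DIMS` table, and both function bodies are hypothetical stand-ins chosen only to show the calling convention; they are not part of pyTransC.

```python
import numpy as np

# Hypothetical three-state problem: state i has N_DIMS[i] parameters.
N_DIMS = [2, 3, 1]


def my_log_posterior(x, state):
    """Unnormalised log-posterior: isotropic Gaussian centred at `state`."""
    x = np.asarray(x, dtype=float)
    if x.shape != (N_DIMS[state],):
        return -np.inf  # wrong number of parameters for this state
    return float(-0.5 * np.sum((x - state) ** 2))


def my_log_pseudo_prior(x, state):
    """Normalised log-density of a unit Gaussian centred at `state`."""
    x = np.asarray(x, dtype=float)
    d = N_DIMS[state]
    # The -0.5 * d * log(2*pi) term makes the density integrate to one,
    # which is the normalisation the pseudo-prior is required to have.
    return float(-0.5 * np.sum((x - state) ** 2) - 0.5 * d * np.log(2.0 * np.pi))
```

The posterior may be left unnormalised, whereas the pseudo-prior carries its normalising constant explicitly.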
Returns
MultiWalkerProductSpaceChain
Chain results containing state sequences, model parameters, and diagnostics
for all walkers.
Notes
The product space approach embeds all possible states in a single fixed-dimensional space. This allows the use of efficient ensemble samplers like emcee, but requires sampling in a higher-dimensional space than any individual state.
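To make the embedding concrete, the sketch below evaluates the standard product-space target: the log-posterior of the active state plus the log-pseudo-priors of every inactive state. This is the textbook formulation and is only assumed to match what pyTransC evaluates internally; the function name `product_space_log_target` and the toy densities are hypothetical, and how pyTransC packs the state index and parameters into a single vector is not shown here.

```python
import numpy as np


def product_space_log_target(k, xs, log_posterior, log_pseudo_prior):
    """Log target for active state k given parameter vectors xs of all states.

    log pi(k, x_1..x_K) = log_posterior(x_k, k)
                          + sum over j != k of log_pseudo_prior(x_j, j)
    """
    total = log_posterior(xs[k], k)
    for j, x_j in enumerate(xs):
        if j != k:
            total += log_pseudo_prior(x_j, j)
    return float(total)


# Toy densities (assumptions for illustration only).
log_post = lambda x, k: -0.5 * float(np.sum(np.square(x)))
log_pp = lambda x, k: (
    -0.5 * float(np.sum(np.square(x))) - 0.5 * len(x) * np.log(2.0 * np.pi)
)

# One parameter vector per state, with dimensions 2, 3 and 1.
xs = [np.zeros(2), np.zeros(3), np.zeros(1)]
value = product_space_log_target(1, xs, log_post, log_pp)
```

The sampler therefore always carries parameters for every state, which is why the sampled space is larger than any single state.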
Examples
Basic usage:
>>> ps = ProductSpace(n_dims=[2, 3, 1])
>>> results = run_product_space_sampler(
...     product_space=ps,
...     n_walkers=32,
...     n_steps=1000,
...     start_positions=[[0.5, 0.5], [1.0, 0.0, -1.0], [2.0]],
...     start_states=[0, 1, 2],
...     log_posterior=my_log_posterior,
...     log_pseudo_prior=my_log_pseudo_prior
... )
Using with schwimmbad pools:
>>> from schwimmbad import MPIPool
>>> with MPIPool() as pool:
...     results = run_product_space_sampler(
...         product_space=ps,
...         n_walkers=32,
...         n_steps=1000,
...         start_positions=start_pos,
...         start_states=start_states,
...         log_posterior=my_log_posterior,
...         log_pseudo_prior=my_log_pseudo_prior,
...         pool=pool
...     )
Using with forward pool for parallel forward solver calls:
>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
...     results = run_product_space_sampler(
...         product_space=ps,
...         n_walkers=32,
...         n_steps=1000,
...         start_positions=start_pos,
...         start_states=start_states,
...         log_posterior=my_log_posterior,
...         log_pseudo_prior=my_log_pseudo_prior,
...         forward_pool=forward_pool
...     )
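The only requirement placed on `forward_pool` is a `map()` method that behaves like the built-in `map()`. The sketch below shows how a likelihood might spread independent forward solves over such a pool. The `forward` solver and `log_likelihood` function are hypothetical, and the pool is passed explicitly to keep the example self-contained; inside pyTransC the documentation says it is retrieved with `get_forward_pool()` instead.

```python
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def forward(source):
    """Hypothetical forward solver for one source (stand-in for real physics)."""
    return float(np.sum(np.sin(source)))


def log_likelihood(x, data, pool=None):
    """Gaussian log-likelihood with one forward solve per datum."""
    # Split the parameter vector into one chunk of parameters per datum.
    sources = np.split(np.asarray(x, dtype=float), len(data))
    # Any object with a map() method works; fall back to the built-in map.
    mapper = pool.map if pool is not None else map
    predictions = np.array(list(mapper(forward, sources)))
    return float(-0.5 * np.sum((predictions - data) ** 2))


x = np.array([0.1, 0.2, 0.3, 0.4])
data = np.array([0.3, 0.7])

serial = log_likelihood(x, data)
with ThreadPoolExecutor(max_workers=2) as pool:
    pooled = log_likelihood(x, data, pool=pool)
```

Because `Executor.map` returns results in input order, the pooled and serial evaluations agree.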
def run_ensemble_resampler(  # Independent state Marginal Likelihoods from pre-computed posterior and pseudo prior ensembles
    n_walkers,
    n_steps,
    n_states: int,
    n_dims: list[int],
    log_posterior_ens: StateOrderedEnsemble,
    log_pseudo_prior_ens: StateOrderedEnsemble,
    seed=61254557,
    state_proposal_weights: list[list[float]] | None = None,
    progress=False,
    walker_pool=None,
    state_pool=None,
    forward_pool=None,
) -> MultiWalkerEnsembleResamplerChain:
    """Run MCMC sampler over independent states using pre-computed ensembles."""

    n_samples = [len(log_post_ens) for log_post_ens in log_posterior_ens]

    # Early validation of forward pool if provided
    if forward_pool is not None:
        from ..utils.forward_context import set_forward_pool, clear_forward_pool
        set_forward_pool(forward_pool)  # Validates map() method
        clear_forward_pool()  # Clear after validation

    if state_proposal_weights is None:
        # uniform proposal weights
        _state_proposal_weights = [[1.0] * n_states] * n_states
    else:
        _state_proposal_weights = np.array(state_proposal_weights)
        np.fill_diagonal(_state_proposal_weights, 0.0)  # ensure diagonal is zero
        _state_proposal_weights = _state_proposal_weights / _state_proposal_weights.sum(
            axis=1, keepdims=True
        )  # set row sums to unity
        _state_proposal_weights = _state_proposal_weights.tolist()

    logger.info("\nRunning ensemble resampler")
    logger.info("\nNumber of walkers               : %d", n_walkers)
    logger.info("Number of states being sampled  : %d", n_states)
    logger.info("Dimensions of each state        : %s", n_dims)

    random.seed(seed)
    if walker_pool is not None:
        chains = _run_mcmc_walker_parallel(
            n_walkers,
            n_states,
            n_samples,
            n_steps,
            log_posterior_ens,
            log_pseudo_prior_ens,
            state_proposal_weights=_state_proposal_weights,
            progress=progress,
            walker_pool=walker_pool,
            forward_pool=forward_pool,
        )

    else:
        chains = _run_mcmc_walker_serial(
            n_walkers,
            n_states,
            n_samples,
            n_steps,
            log_posterior_ens,
            log_pseudo_prior_ens,
            state_proposal_weights=_state_proposal_weights,
            progress=progress,
            forward_pool=forward_pool,
        )

    return MultiWalkerEnsembleResamplerChain(chains)
Run MCMC sampler over independent states using pre-computed ensembles.
This function performs trans-conceptual MCMC by resampling from pre-computed posterior ensembles in each state. It calculates relative evidence of each state by sampling over the ensemble members according to their posterior and pseudo-prior densities.
Parameters
n_walkers : int
Number of random walkers used by the ensemble resampler.
n_steps : int
Number of Markov chain steps to perform per walker.
n_states : int
Number of independent states in the problem.
n_dims : list of int
List of parameter dimensions for each state.
log_posterior_ens : StateOrderedEnsemble
Log-posterior values of ensemble members in each state.
Format: list of arrays, where each array contains log-posterior values
for the ensemble members in that state.
log_pseudo_prior_ens : StateOrderedEnsemble
Log-pseudo-prior values of ensemble members in each state.
Format: list of arrays, where each array contains log-pseudo-prior values
for the ensemble members in that state.
seed : int, optional
Random number seed for reproducible results. Default is 61254557.
state_proposal_weights : list of list of float, optional
Weights for proposing transitions between states. Should be a matrix
where element [i][j] is the weight for proposing state j from state i.
Diagonal elements are ignored. If None, uniform weights are used.
progress : bool, optional
Whether to display progress information. Default is False.
walker_pool : Any | None, optional
User-provided pool for parallelizing walker execution. The pool must
implement a map() method compatible with the standard library's map()
function. Default is None.
state_pool : Any | None, optional
User-provided pool for parallelizing state-level operations such as
pseudo-prior evaluation across states. Currently reserved for future
enhancements. Default is None.
forward_pool : Any | None, optional
User-provided pool for parallelizing forward solver calls within
log_posterior evaluations. If provided, the pool will be made available
to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
The pool must implement a map() method compatible with the standard library's
map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
and schwimmbad pools. Default is None.
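The handling of user-supplied `state_proposal_weights` can be reproduced in isolation. The sketch below mirrors the normalisation in the source listing of `run_ensemble_resampler` (zero the diagonal, then scale each row to sum to one); the helper name `normalise_proposal_weights` and the example matrix are hypothetical.

```python
import numpy as np


def normalise_proposal_weights(weights):
    """Zero the diagonal and rescale each row so that it sums to one."""
    w = np.array(weights, dtype=float)  # copy, so the caller's list is untouched
    np.fill_diagonal(w, 0.0)            # a state never proposes itself
    w = w / w.sum(axis=1, keepdims=True)
    return w.tolist()


weights = [
    [5.0, 1.0, 3.0],
    [2.0, 9.0, 2.0],
    [1.0, 1.0, 1.0],
]
normalised = normalise_proposal_weights(weights)
```

Row `i` of the result gives the probability of proposing each other state `j` from state `i`, so the diagonal values supplied by the user have no effect.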
Returns
MultiWalkerEnsembleResamplerChain
Chain results containing state sequences, ensemble member indices,
and diagnostics for all walkers.
Notes
This method requires pre-computed posterior ensembles and their corresponding
log-density values. The ensembles can be generated using run_mcmc_per_state()
and the pseudo-prior values using automatic fitting routines.
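As an illustration of the expected input format, the sketch below fits a multivariate Gaussian to each state's ensemble and evaluates its normalised log-density at the ensemble members, producing one array per state. This is a hypothetical stand-in written with NumPy only; it is not pyTransC's own fitting routine, and the random ensembles merely take the place of real posterior samples.

```python
import numpy as np


def fit_gaussian_log_density(ensemble):
    """Fit a Gaussian to samples of shape (n_samples, n_dims).

    Returns a function that evaluates the normalised log-density
    at points of shape (n_points, n_dims).
    """
    ensemble = np.atleast_2d(ensemble)
    d = ensemble.shape[1]
    mean = ensemble.mean(axis=0)
    cov = np.atleast_2d(np.cov(ensemble, rowvar=False))
    _, logdet = np.linalg.slogdet(cov)

    def log_density(x):
        diff = np.atleast_2d(x) - mean
        # Squared Mahalanobis distance of every point from the mean.
        maha = np.einsum("ni,ni->n", diff, np.linalg.solve(cov, diff.T).T)
        return -0.5 * (maha + logdet + d * np.log(2.0 * np.pi))

    return log_density


# Stand-in ensembles for three states with 2, 3 and 1 parameters.
rng = np.random.default_rng(0)
ensembles = [rng.normal(size=(500, d)) for d in (2, 3, 1)]

# One array of log-pseudo-prior values per state, aligned with the ensemble.
log_pseudo_prior_ens = [fit_gaussian_log_density(e)(e) for e in ensembles]
```

A Gaussian is used here only because its normalising constant is known in closed form, which satisfies the requirement that the pseudo-prior be normalised.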
The algorithm works by:
- Selecting ensemble members within states based on posterior weights
- Proposing transitions between states based on relative evidence
- Accepting/rejecting proposals using the Metropolis-Hastings criterion
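The final accept/reject step can be sketched generically in log space. The helper `mh_accept` below is hypothetical and shows only the Metropolis-Hastings decision itself; how pyTransC assembles the log acceptance ratio from the posterior and pseudo-prior ensembles is not reproduced here.

```python
import math
import random


def mh_accept(log_ratio, rng=random):
    """Return True with probability min(1, exp(log_ratio))."""
    if log_ratio >= 0.0:
        return True  # proposals that do not lower the target are always accepted
    # log_ratio < 0 here, so exp() cannot overflow; -inf gives probability 0.
    return rng.random() < math.exp(log_ratio)


# Empirical acceptance rate for a ratio of 0.25, using a seeded generator.
rng = random.Random(1)
n_trials = 20000
n_accepted = sum(mh_accept(math.log(0.25), rng) for _ in range(n_trials))
acceptance_rate = n_accepted / n_trials
```

Working with the log of the ratio avoids overflow and underflow when the densities involved differ by many orders of magnitude.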
Examples
>>> results = run_ensemble_resampler(
...     n_walkers=32,
...     n_steps=1000,
...     n_states=3,
...     n_dims=[2, 3, 1],
...     log_posterior_ens=posterior_ensembles,
...     log_pseudo_prior_ens=pseudo_prior_ensembles
... )
Using with forward pool for parallel forward solver calls:
>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
...     results = run_ensemble_resampler(
...         n_walkers=32,
...         n_steps=1000,
...         n_states=3,
...         n_dims=[2, 3, 1],
...         log_posterior_ens=posterior_ensembles,
...         log_pseudo_prior_ens=pseudo_prior_ensembles,
...         forward_pool=forward_pool
...     )
def run_mcmc_per_state(
    n_states: int,
    n_dims: list[int],
    n_walkers: int | list[int],
    n_steps: int | list[int],
    pos: list[FloatArray],
    log_posterior: MultiStateDensity,
    discard: int | list[int] = 0,
    thin: int | list[int] = 1,
    auto_thin: bool = False,
    seed: int = 61254557,
    state_pool: Any | None = None,
    emcee_pool: Any | None = None,
    n_state_processors: int | None = None,
    skip_initial_state_check: bool = False,
    verbose: bool = True,
    forward_pool: Any | None = None,
    **kwargs,
) -> tuple[list[FloatArray], list[FloatArray]]:
    """Run independent MCMC sampling within each state."""

    random.seed(seed)

    # Early validation of forward pool if provided
    if forward_pool is not None:
        from ..utils.forward_context import set_forward_pool, clear_forward_pool
        set_forward_pool(forward_pool)  # Validates map() method
        clear_forward_pool()  # Clear after validation

    if not isinstance(n_walkers, list):
        n_walkers = [n_walkers] * n_states
    if not isinstance(discard, list):
        discard = [discard] * n_states
    if not isinstance(thin, list):
        thin = [thin] * n_states
    if not isinstance(n_steps, list):
        n_steps = [n_steps] * n_states

    if auto_thin:
        # ignore thinning factor because we are post thinning by the auto-correlation times
        # burn_in is also calculated from the auto-correlation time
        thin = [1] * n_states
        discard = [0] * n_states

    if verbose:
        print("\nRunning within-state sampler separately on each state")
        print("\nNumber of walkers               : ", n_walkers)
        print("\nNumber of states being sampled: ", n_states)
        print("Dimensions of each state: ", n_dims)
        if state_pool is not None:
            print("Using state-level parallelism")
        if emcee_pool is not None:
            print("Using walker-level parallelism")

    # Prepare emcee pool configuration to avoid pickling issues
    emcee_pool_config = None
    if emcee_pool is not None:
        # Determine pool type and configuration
        if hasattr(emcee_pool, '__class__'):
            pool_class_name = emcee_pool.__class__.__name__
            if pool_class_name == 'ProcessPoolExecutor':
                emcee_pool_config = {
                    'type': 'ProcessPoolExecutor',
                    'kwargs': {'max_workers': emcee_pool._max_workers}
                }
            elif pool_class_name == 'ThreadPoolExecutor':
                emcee_pool_config = {
                    'type': 'ThreadPoolExecutor',
                    'kwargs': {'max_workers': emcee_pool._max_workers}
                }

    # Prepare forward pool configuration to avoid pickling issues
    forward_pool_config = None
    if forward_pool is not None and state_pool is not None:
        # Only create config when using state-level parallelism
        if hasattr(forward_pool, '__class__'):
            pool_class_name = forward_pool.__class__.__name__
            if pool_class_name == 'ProcessPoolExecutor':
                forward_pool_config = {
                    'type': 'ProcessPoolExecutor',
                    'kwargs': {'max_workers': forward_pool._max_workers}
                }
            elif pool_class_name == 'ThreadPoolExecutor':
                forward_pool_config = {
                    'type': 'ThreadPoolExecutor',
                    'kwargs': {'max_workers': forward_pool._max_workers}
                }

    # Prepare state processing arguments
    state_args = []
    for i in range(n_states):
        args_dict = {
            'state_idx': i,
            'log_posterior': log_posterior,
            'n_walkers': n_walkers[i],
            'pos': pos[i],
            'n_steps': n_steps[i],
            'discard': discard[i],
            'thin': thin[i],
            'skip_initial_state_check': skip_initial_state_check,
            'verbose': verbose,
            **kwargs
        }

        # Add pool configs for state-level parallelism
        if state_pool is not None:
            args_dict['emcee_pool_config'] = emcee_pool_config
            if forward_pool_config is not None:
                args_dict['forward_pool_config'] = forward_pool_config
            # Don't pass the actual pool objects
        else:
            # For sequential processing, pass the pools directly
            args_dict['emcee_pool'] = emcee_pool
            args_dict['forward_pool'] = forward_pool

        state_args.append(args_dict)

    # Process states in parallel or sequentially
    if state_pool is not None:
        # Use provided state pool for parallel processing
        results = list(state_pool.map(_process_single_state, state_args))
    elif n_state_processors is not None and n_state_processors > 1:
        # Create internal ProcessPoolExecutor for state parallelism
        with ProcessPoolExecutor(max_workers=n_state_processors) as executor:
            results = list(executor.map(_process_single_state, state_args))
    else:
        # Sequential processing (original behavior)
        # For sequential, we can pass pools directly since no pickling needed
        for args in state_args:
            if 'emcee_pool_config' in args:
                del args['emcee_pool_config']
            if 'forward_pool_config' in args:
                del args['forward_pool_config']
            args['emcee_pool'] = emcee_pool
            args['forward_pool'] = forward_pool
        results = [_process_single_state(args) for args in state_args]

    # Unpack results
    samples = [result[0] for result in results]
    log_posterior_ens = [result[1] for result in results]
    auto_correlation = [result[2] for result in results]

    if auto_thin:
        samples, log_posterior_ens = _perform_auto_thinning(
            samples, log_posterior_ens, auto_correlation, verbose=verbose
        )

    return samples, log_posterior_ens
Run independent MCMC sampling within each state.
This utility function runs the emcee sampler independently within each state to generate posterior ensembles. These ensembles can then be used as input for ensemble resampling or for constructing pseudo-priors.
Parameters
n_states : int
Number of independent states in the problem.
n_dims : list of int
List of parameter dimensions for each state.
n_walkers : int or list of int
Number of random walkers for the emcee sampler. If int, same number
is used for all states. If list, specifies walkers per state.
n_steps : int or list of int
Number of MCMC steps per walker. If int, same number is used for all
states. If list, specifies steps per state.
pos : list of FloatArray
Starting positions for each state. Each array should have shape
(n_walkers[state], n_dims[state]).
log_posterior : MultiStateDensity
Function to evaluate the log-posterior density at location x in state i.
Must have signature log_posterior(x, state) -> float.
discard : int or list of int, optional
Number of samples to discard as burn-in. If int, same value used for
all states. Default is 0.
thin : int or list of int, optional
Thinning factor for chains. If int, same value used for all states.
Default is 1 (no thinning).
auto_thin : bool, optional
If True, automatically thin chains based on autocorrelation time,
ignoring the thin parameter. Default is False.
seed : int, optional
Random number seed for reproducible results. Default is 61254557.
skip_initial_state_check : bool, optional
Whether to skip emcee's initial state check. Default is False.
verbose : bool, optional
Whether to print progress information. Default is True.
state_pool : Any | None, optional
User-provided pool for parallelizing state execution. If provided, states
are processed in parallel. The pool must implement a map() method compatible
with the standard library's map() function. Default is None.
emcee_pool : Any | None, optional
User-provided pool for parallelizing emcee walker execution within each state.
The pool must implement a map() method compatible with the standard library's
map() function. Default is None.
forward_pool : Any | None, optional
User-provided pool for parallelizing forward solver calls within
log_posterior evaluations. If provided, the pool will be made available
to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
The pool must implement a map() method compatible with the standard library's
map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
and schwimmbad pools. Default is None.
**kwargs
Additional keyword arguments passed to emcee.EnsembleSampler.
Returns
ensemble_per_state : list of FloatArray
Posterior samples for each state. Each array has shape
(n_samples, n_dims[state]).
log_posterior_ens : list of FloatArray
Log posterior values for each ensemble. Each array has shape (n_samples,).
Notes
This function is primarily a convenience wrapper around emcee for generating posterior ensembles within each state independently. The resulting ensembles can be used with:
- run_ensemble_resampler() for ensemble-based trans-dimensional sampling
- Automatic pseudo-prior construction functions
- Direct analysis of within-state posterior distributions
If auto_thin=True, the function will automatically determine appropriate
burn-in and thinning based on the autocorrelation time, following emcee
best practices.
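The internals of the automatic thinning step are not shown here. As a rough illustration only, the sketch below applies the rule of thumb from the emcee tutorials (discard about twice the largest autocorrelation time, thin by about half the smallest). The function name `thin_by_autocorr` and the exact factors are assumptions, not the library's implementation.

```python
import numpy as np


def thin_by_autocorr(chain, log_prob, tau):
    """Discard burn-in and thin a chain of shape (n_steps, n_walkers, n_dim).

    Assumed rule of thumb: burn-in = 2 * max(tau), thin = 0.5 * min(tau).
    """
    tau = np.atleast_1d(tau)
    burn_in = int(2 * np.max(tau))
    thin = max(1, int(0.5 * np.min(tau)))
    kept = chain[burn_in::thin]
    kept_log_prob = log_prob[burn_in::thin]
    # Flatten the step and walker axes into a single sample axis.
    return kept.reshape(-1, chain.shape[-1]), kept_log_prob.reshape(-1)
```

With emcee, `tau` would typically come from `sampler.get_autocorr_time()`, and the same effect is available through the `discard`, `thin` and `flat` arguments of `sampler.get_chain()`.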
Examples
Basic usage:
>>> import numpy as np
>>> ensembles, log_probs = run_mcmc_per_state(
... n_states=2,
... n_dims=[3, 2],
... n_walkers=32,
... n_steps=1000,
... pos=[np.random.randn(32, 3), np.random.randn(32, 2)],
... log_posterior=my_log_posterior,
... auto_thin=True
... )
Using with state-level parallelism:
>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as state_pool:
... ensembles, log_probs = run_mcmc_per_state(
... n_states=4,
... n_dims=[3, 2, 4, 1],
... n_walkers=32,
... n_steps=1000,
... pos=initial_positions,
... log_posterior=my_log_posterior,
... state_pool=state_pool
... )
Using with both state and emcee parallelism:
>>> from schwimmbad import MPIPool
>>> with MPIPool() as state_pool, ProcessPoolExecutor() as emcee_pool:
... ensembles, log_probs = run_mcmc_per_state(
... n_states=2,
... n_dims=[3, 2],
... n_walkers=32,
... n_steps=1000,
... pos=initial_positions,
... log_posterior=my_log_posterior,
... state_pool=state_pool,
... emcee_pool=emcee_pool
... )
Using with forward pool for parallel forward solver calls:
>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
... ensembles, log_probs = run_mcmc_per_state(
... n_states=3,
... n_dims=[2, 3, 1],
... n_walkers=32,
... n_steps=1000,
... pos=initial_positions,
... log_posterior=my_log_posterior,
... forward_pool=forward_pool
... )
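The examples above leave `my_log_posterior` and `initial_positions` undefined. A minimal sketch of what they might look like for the two-state basic example follows; the Gaussian target and the per-state `scale` are purely illustrative assumptions.

```python
import numpy as np

N_DIMS = [3, 2]   # parameter count per state, as in the basic example
N_WALKERS = 32


def my_log_posterior(x, state):
    """Toy unnormalized log-posterior: zero-mean Gaussian in each state."""
    x = np.asarray(x, dtype=float)
    if x.shape != (N_DIMS[state],):
        return -np.inf  # wrong dimension for this state
    scale = 1.0 + state  # hypothetical: wider Gaussian in higher states
    return float(-0.5 * np.sum((x / scale) ** 2))


rng = np.random.default_rng(0)
# One array of shape (n_walkers, n_dims[state]) per state, near the origin.
initial_positions = [0.1 * rng.standard_normal((N_WALKERS, d)) for d in N_DIMS]
```

When a process-based pool is used, the log-posterior should be defined at module level (not as a lambda or nested function) so that it can be pickled.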
def run_state_jump_sampler(  # Independent state MCMC sampler on product space with proposal equal to pseudo prior
    n_walkers,
    n_steps,
    n_states: int,
    n_dims: list[int],
    start_positions: list[FloatArray],
    start_states: list[int],
    log_posterior: MultiStateDensity,
    log_pseudo_prior: SampleableMultiStateDensity,
    log_proposal: ProposableMultiStateDensity,
    prob_state=0.1,
    seed=61254557,
    progress=False,
    walker_pool=None,
    forward_pool=None,
) -> MultiWalkerStateJumpChain:
    """Run MCMC sampler with direct jumps between states of different dimensions.

    This function implements trans-conceptual MCMC using a Metropolis-Hastings
    algorithm that can propose jumps between states with different numbers of
    parameters. Between-state moves use the pseudo-prior as the proposal, while
    within-state moves use a user-defined proposal function.

    Parameters
    ----------
    n_walkers : int
        Number of random walkers used by the state jump sampler.
    n_steps : int
        Number of MCMC steps required per walker.
    n_states : int
        Number of independent states in the problem.
    n_dims : list of int
        List of parameter dimensions for each state.
    start_positions : list of FloatArray
        Starting parameter positions for each walker. Each array should contain
        the initial parameter values for the corresponding starting state.
    start_states : list of int
        Starting state indices for each walker.
    log_posterior : MultiStateDensity
        Function to evaluate the log-posterior density at location x in state i.
        Must have signature log_posterior(x, state) -> float.
    log_pseudo_prior : SampleableMultiStateDensity
        Object with methods:
        - __call__(x, state) -> float: evaluate log pseudo-prior at x for state
        - draw_deviate(state) -> FloatArray: sample from pseudo-prior for state
        Note: Must be normalized over respective state spaces.
    log_proposal : ProposableMultiStateDensity
        Object with methods:
        - propose(x_current, state) -> FloatArray: propose new x in state
        - __call__(x, state) -> float: log proposal probability (for MH ratio)
    prob_state : float, optional
        Probability of proposing a state change per MCMC step. Otherwise,
        a parameter change within the current state is proposed. Default is 0.1.
    seed : int, optional
        Random number seed for reproducible results. Default is 61254557.
    progress : bool, optional
        Whether to display progress information. Default is False.
    walker_pool : Any | None, optional
        User-provided pool for parallelizing walker execution. The pool must
        implement a map() method compatible with the standard library's map()
        function. Default is None.
    forward_pool : Any | None, optional
        User-provided pool for parallelizing forward solver calls within
        log_posterior evaluations. If provided, the pool will be made available
        to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
        The pool must implement a map() method compatible with the standard library's
        map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
        and schwimmbad pools. Default is None.

    Returns
    -------
    MultiWalkerStateJumpChain
        Chain results containing state sequences, model parameters, proposal
        acceptance rates, and diagnostics for all walkers.

    Notes
    -----
    The algorithm uses a Metropolis-Hastings sampler with two types of moves:

    1. **Between-state moves** (probability `prob_state`):
       - Propose a new state uniformly at random
       - Generate new parameters from the pseudo-prior of the proposed state
       - Accept/reject based on posterior and pseudo-prior ratios

    2. **Within-state moves** (probability `1 - prob_state`):
       - Use the user-defined proposal function to generate new parameters
       - Accept/reject using standard Metropolis-Hastings criterion

    The pseudo-prior must be normalized for the between-state acceptance
    criterion to satisfy detailed balance.

    Examples
    --------
    Basic usage:

    >>> results = run_state_jump_sampler(
    ...     n_walkers=3,
    ...     n_steps=1000,
    ...     n_states=3,
    ...     n_dims=[2, 3, 1],
    ...     start_positions=[[0.5, 0.5], [1.0, 0.0, -1.0], [2.0]],
    ...     start_states=[0, 1, 2],
    ...     log_posterior=my_log_posterior,
    ...     log_pseudo_prior=my_log_pseudo_prior,
    ...     log_proposal=my_log_proposal,
    ...     prob_state=0.2
    ... )

    Using with user-provided walker pool:

    >>> from concurrent.futures import ProcessPoolExecutor
    >>> with ProcessPoolExecutor(max_workers=4) as walker_pool:
    ...     results = run_state_jump_sampler(
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         n_states=3,
    ...         n_dims=[2, 3, 1],
    ...         start_positions=start_positions,
    ...         start_states=start_states,
    ...         log_posterior=my_log_posterior,
    ...         log_pseudo_prior=my_log_pseudo_prior,
    ...         log_proposal=my_log_proposal,
    ...         walker_pool=walker_pool
    ...     )

    Using with forward pool for parallel forward solver calls:

    >>> from concurrent.futures import ProcessPoolExecutor
    >>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
    ...     results = run_state_jump_sampler(
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         n_states=3,
    ...         n_dims=[2, 3, 1],
    ...         start_positions=start_positions,
    ...         start_states=start_states,
    ...         log_posterior=my_log_posterior,
    ...         log_pseudo_prior=my_log_pseudo_prior,
    ...         log_proposal=my_log_proposal,
    ...         forward_pool=forward_pool
    ...     )
    """

    logger.info("Running state-jump trans-C sampler")
    logger.info("Number of walkers: %d", n_walkers)
    logger.info("Number of states being sampled: %d", n_states)
    logger.info("Dimensions of each state: %s", n_dims)

    # Early validation of forward pool if provided
    if forward_pool is not None:
        from ..utils.forward_context import set_forward_pool, clear_forward_pool
        set_forward_pool(forward_pool)  # Validates map() method
        clear_forward_pool()  # Clear after validation

    random.seed(seed)

    if walker_pool is not None:  # put random walkers on different processors
        chains = _run_state_jump_sampler_parallel(
            n_walkers,
            n_steps,
            n_states,
            start_positions,
            start_states,
            log_posterior,
            log_pseudo_prior,
            log_proposal,
            prob_state=prob_state,
            progress=progress,
            walker_pool=walker_pool,
            forward_pool=forward_pool,
        )
    else:
        chains = _run_state_jump_sampler_serial(
            n_walkers,
            n_steps,
            n_states,
            start_positions,
            start_states,
            log_posterior,
            log_pseudo_prior,
            log_proposal,
            prob_state=prob_state,
            progress=progress,
            forward_pool=forward_pool,
        )
    return MultiWalkerStateJumpChain(chains)
Run MCMC sampler with direct jumps between states of different dimensions.
This function implements trans-conceptual MCMC using a Metropolis-Hastings algorithm that can propose jumps between states with different numbers of parameters. Between-state moves use the pseudo-prior as the proposal, while within-state moves use a user-defined proposal function.
Parameters
n_walkers : int
Number of random walkers used by the state jump sampler.
n_steps : int
Number of MCMC steps required per walker.
n_states : int
Number of independent states in the problem.
n_dims : list of int
List of parameter dimensions for each state.
start_positions : list of FloatArray
Starting parameter positions for each walker. Each array should contain
the initial parameter values for the corresponding starting state.
start_states : list of int
Starting state indices for each walker.
log_posterior : MultiStateDensity
Function to evaluate the log-posterior density at location x in state i.
Must have signature log_posterior(x, state) -> float.
log_pseudo_prior : SampleableMultiStateDensity
Object with methods:
- __call__(x, state) -> float: evaluate log pseudo-prior at x for state
- draw_deviate(state) -> FloatArray: sample from pseudo-prior for state
Note: Must be normalized over respective state spaces.
log_proposal : ProposableMultiStateDensity
Object with methods:
- propose(x_current, state) -> FloatArray: propose new x in state
- __call__(x, state) -> float: log proposal probability (for MH ratio)
prob_state : float, optional
Probability of proposing a state change per MCMC step. Otherwise,
a parameter change within the current state is proposed. Default is 0.1.
seed : int, optional
Random number seed for reproducible results. Default is 61254557.
progress : bool, optional
Whether to display progress information. Default is False.
walker_pool : Any | None, optional
User-provided pool for parallelizing walker execution. The pool must
implement a map() method compatible with the standard library's map()
function. Default is None.
forward_pool : Any | None, optional
User-provided pool for parallelizing forward solver calls within
log_posterior evaluations. If provided, the pool will be made available
to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
The pool must implement a map() method compatible with the standard library's
map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
and schwimmbad pools. Default is None.
Returns
MultiWalkerStateJumpChain
Chain results containing state sequences, model parameters, proposal
acceptance rates, and diagnostics for all walkers.
Notes
The algorithm uses a Metropolis-Hastings sampler with two types of moves:
1. Between-state moves (probability prob_state):
- Propose a new state uniformly at random
- Generate new parameters from the pseudo-prior of the proposed state
- Accept/reject based on posterior and pseudo-prior ratios
2. Within-state moves (probability 1 - prob_state):
- Use the user-defined proposal function to generate new parameters
- Accept/reject using standard Metropolis-Hastings criterion
The pseudo-prior must be normalized for the between-state acceptance criterion to satisfy detailed balance.
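The two move types can be sketched for a single walker as follows. This is a toy illustration under stated assumptions, not pyTransC's implementation: the two-state target, the fixed pseudo-prior width of 1.5, the symmetric random-walk step and the function name `state_jump_step` are all hypothetical. For a between-state move from (x, k) to (x', k') with a uniform state proposal, the log acceptance ratio used here is [log p(x', k') - log q(x', k')] - [log p(x, k) - log q(x, k)], where p is the posterior and q the normalized pseudo-prior.

```python
import numpy as np


def log_norm(x, mean, sd):
    """Log density of independent normals, summed over components."""
    x = np.asarray(x, dtype=float)
    return float(np.sum(-0.5 * ((x - mean) / sd) ** 2 - np.log(sd * np.sqrt(2 * np.pi))))


N_DIMS = [1, 2]
LOG_WEIGHTS = np.log([1.0, 3.0])  # toy relative evidence of the two states


def log_posterior(x, state):
    # Unnormalized across states: evidence weight times a standard normal.
    return LOG_WEIGHTS[state] + log_norm(x, 0.0, 1.0)


def log_pseudo_prior(x, state):
    # Normalized density, deliberately wider than the posterior.
    return log_norm(x, 0.0, 1.5)


def state_jump_step(x, state, rng, prob_state=0.2, step=0.5):
    """One Metropolis-Hastings step; returns the new (x, state)."""
    if rng.random() < prob_state:
        # Between-state move: uniform state proposal, x' drawn from pseudo-prior.
        new_state = int(rng.integers(len(N_DIMS)))
        new_x = 1.5 * rng.standard_normal(N_DIMS[new_state])
        log_alpha = (
            log_posterior(new_x, new_state) - log_pseudo_prior(new_x, new_state)
        ) - (log_posterior(x, state) - log_pseudo_prior(x, state))
    else:
        # Within-state move: symmetric random walk, so proposal terms cancel.
        new_state = state
        new_x = x + step * rng.standard_normal(x.shape)
        log_alpha = log_posterior(new_x, state) - log_posterior(x, state)
    if np.log(rng.random()) < log_alpha:
        return new_x, new_state
    return x, state
```

Iterating `state_jump_step` from any starting point, the fraction of steps spent in each state should approach the relative evidence of the states (1:3 in this toy problem).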
Examples
Basic usage:
>>> results = run_state_jump_sampler(
... n_walkers=3,
... n_steps=1000,
... n_states=3,
... n_dims=[2, 3, 1],
... start_positions=[[0.5, 0.5], [1.0, 0.0, -1.0], [2.0]],
... start_states=[0, 1, 2],
... log_posterior=my_log_posterior,
... log_pseudo_prior=my_log_pseudo_prior,
... log_proposal=my_log_proposal,
... prob_state=0.2
... )
Using with user-provided walker pool:
>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as walker_pool:
... results = run_state_jump_sampler(
... n_walkers=32,
... n_steps=1000,
... n_states=3,
... n_dims=[2, 3, 1],
... start_positions=start_positions,
... start_states=start_states,
... log_posterior=my_log_posterior,
... log_pseudo_prior=my_log_pseudo_prior,
... log_proposal=my_log_proposal,
... walker_pool=walker_pool
... )
Using with forward pool for parallel forward solver calls:
>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
... results = run_state_jump_sampler(
... n_walkers=32,
... n_steps=1000,
... n_states=3,
... n_dims=[2, 3, 1],
... start_positions=start_positions,
... start_states=start_states,
... log_posterior=my_log_posterior,
... log_pseudo_prior=my_log_pseudo_prior,
... log_proposal=my_log_proposal,
... forward_pool=forward_pool
... )
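The examples above assume that `my_log_pseudo_prior` and `my_log_proposal` already exist. A minimal sketch of objects exposing the method names listed under Parameters is given below. The class names, the Gaussian forms and the constant returned by the proposal's `__call__` are assumptions; in particular, how the sampler combines the proposal log-probabilities in the Metropolis-Hastings ratio is not shown in this section.

```python
import numpy as np


class GaussianPseudoPrior:
    """Hypothetical pseudo-prior: independent normals in each state."""

    def __init__(self, means, sds, seed=0):
        self.means = [np.asarray(m, dtype=float) for m in means]
        self.sds = [np.asarray(s, dtype=float) for s in sds]
        self.rng = np.random.default_rng(seed)

    def __call__(self, x, state):
        # Normalized log density, as the between-state move requires.
        z = (np.asarray(x, dtype=float) - self.means[state]) / self.sds[state]
        log_norm_const = np.sum(np.log(self.sds[state] * np.sqrt(2.0 * np.pi)))
        return float(-0.5 * np.sum(z**2) - log_norm_const)

    def draw_deviate(self, state):
        noise = self.rng.standard_normal(self.means[state].shape)
        return self.means[state] + self.sds[state] * noise


class RandomWalkProposal:
    """Hypothetical symmetric Gaussian random-walk proposal."""

    def __init__(self, step_sizes, seed=1):
        self.step_sizes = [np.asarray(s, dtype=float) for s in step_sizes]
        self.rng = np.random.default_rng(seed)

    def propose(self, x_current, state):
        x_current = np.asarray(x_current, dtype=float)
        noise = self.rng.standard_normal(x_current.shape)
        return x_current + self.step_sizes[state] * noise

    def __call__(self, x, state):
        # Assumption: for a symmetric proposal the forward and reverse terms
        # cancel in the MH ratio, so a constant is taken to be sufficient.
        return 0.0


my_log_pseudo_prior = GaussianPseudoPrior(
    means=[[0.0, 0.0], [0.0, 0.0, 0.0], [0.0]],
    sds=[[1.0, 1.0], [1.0, 1.0, 1.0], [2.0]],
)
my_log_proposal = RandomWalkProposal(step_sizes=[[0.3, 0.3], [0.3, 0.3, 0.3], [0.5]])
```

Both objects keep their own random generator here for simplicity; with a process-based walker pool, each worker would receive a copy of that generator state, so per-walker seeding may be needed.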