pytransc.samplers

Sampling algorithms for pyTransC.

This module provides the main sampling algorithms for trans-conceptual MCMC:

  • Product space sampling: Fixed-dimensional sampling over the product space
  • State-jump sampling: Direct jumping between different conceptual states
  • Ensemble resampling: Resampling from pre-computed posterior ensembles
  • Per-state MCMC: Independent sampling within each state

Each sampler has different advantages and use cases depending on the problem structure and computational requirements.

"""Sampling algorithms for pyTransC.

This module provides the main sampling algorithms for trans-conceptual MCMC:

- Product space sampling: Fixed-dimensional sampling over the product space
- State-jump sampling: Direct jumping between different conceptual states
- Ensemble resampling: Resampling from pre-computed posterior ensembles
- Per-state MCMC: Independent sampling within each state

Each sampler has different advantages and use cases depending on the problem
structure and computational requirements.
"""

from .ensemble_resampler import run_ensemble_resampler
from .per_state import run_mcmc_per_state
from .product_space import run_product_space_sampler
from .state_jump import run_state_jump_sampler

__all__ = [
    "run_product_space_sampler",
    "run_ensemble_resampler",
    "run_mcmc_per_state",
    "run_state_jump_sampler",
]
def run_product_space_sampler( product_space: pytransc.samplers.product_space.ProductSpace, n_walkers: int, n_steps: int, start_positions: list[numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]], start_states: list[int], log_posterior: pytransc.utils.types.MultiStateDensity, log_pseudo_prior: pytransc.utils.types.MultiStateDensity, seed: int | None = 61254557, progress: bool = False, pool: typing.Any | None = None, forward_pool: typing.Any | None = None, **kwargs) -> pytransc.samplers.product_space.MultiWalkerProductSpaceChain:
def run_product_space_sampler(
    product_space: ProductSpace,
    n_walkers: int,
    n_steps: int,
    start_positions: list[FloatArray],
    start_states: list[int],
    log_posterior: MultiStateDensity,
    log_pseudo_prior: MultiStateDensity,
    seed: int | None = 61254557,
    progress: bool = False,
    pool: Any | None = None,
    forward_pool: Any | None = None,
    **kwargs,
) -> MultiWalkerProductSpaceChain:
    """Run MCMC sampler over independent states using emcee in trans-C product space.

    This function implements trans-conceptual MCMC sampling by embedding all states
    in a fixed-dimensional product space. The sampler uses the emcee ensemble sampler
    to explore the combined parameter space of all states.

    Parameters
    ----------
    product_space : ProductSpace
        The product space definition containing state dimensions.
    n_walkers : int
        Number of random walkers used by the product space sampler.
    n_steps : int
        Number of MCMC steps required per walker.
    start_positions : list of FloatArray
        Starting positions for walkers, one array per walker containing the
        initial parameter values for the starting state.
    start_states : list of int
        Starting state indices for each walker.
    log_posterior : MultiStateDensity
        Function to evaluate the log-posterior density at location x in state i.
        Must have signature log_posterior(x, state) -> float.
    log_pseudo_prior : MultiStateDensity
        Function to evaluate the log-pseudo-prior density at location x in state i.
        Must have signature log_pseudo_prior(x, state) -> float.
        Note: Must be normalized over respective state spaces.
    seed : int, optional
        Random number seed for reproducible results. Default is 61254557.
    progress : bool, optional
        Whether to display progress information. Default is False.
    pool : Any | None, optional
        User-provided pool for parallel processing. The pool must implement
        a map() method compatible with the standard library's map() function.
        Default is None.
    forward_pool : Any | None, optional
        User-provided pool for parallelizing forward solver calls within
        log_posterior evaluations. If provided, the pool will be made available
        to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
        The pool must implement a map() method compatible with the standard library's
        map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
        and schwimmbad pools. Default is None.
    **kwargs
        Additional keyword arguments passed to the emcee sampler.

    Returns
    -------
    MultiWalkerProductSpaceChain
        Chain results containing state sequences, model parameters, and diagnostics
        for all walkers.

    Notes
    -----
    The product space approach embeds all possible states in a single fixed-dimensional
    space. This allows the use of efficient ensemble samplers like emcee, but requires
    sampling in a higher-dimensional space than any individual state.

    Examples
    --------
    Basic usage:

    >>> ps = ProductSpace(n_dims=[2, 3, 1])
    >>> results = run_product_space_sampler(
    ...     product_space=ps,
    ...     n_walkers=32,
    ...     n_steps=1000,
    ...     start_positions=[[0.5, 0.5], [1.0, 0.0, -1.0], [2.0]],
    ...     start_states=[0, 1, 2],
    ...     log_posterior=my_log_posterior,
    ...     log_pseudo_prior=my_log_pseudo_prior
    ... )

    Using with schwimmbad pools:

    >>> from schwimmbad import MPIPool
    >>> with MPIPool() as pool:
    ...     results = run_product_space_sampler(
    ...         product_space=ps,
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         start_positions=start_pos,
    ...         start_states=start_states,
    ...         log_posterior=my_log_posterior,
    ...         log_pseudo_prior=my_log_pseudo_prior,
    ...         pool=pool
    ...     )

    Using with forward pool for parallel forward solver calls:

    >>> from concurrent.futures import ProcessPoolExecutor
    >>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
    ...     results = run_product_space_sampler(
    ...         product_space=ps,
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         start_positions=start_pos,
    ...         start_states=start_states,
    ...         log_posterior=my_log_posterior,
    ...         log_pseudo_prior=my_log_pseudo_prior,
    ...         forward_pool=forward_pool
    ...     )
    """

    random.seed(seed)

    # Early validation of forward pool if provided
    if forward_pool is not None:
        from ..utils.forward_context import set_forward_pool, clear_forward_pool
        set_forward_pool(forward_pool)  # Validates map() method
        clear_forward_pool()  # Clear after validation

    if progress:
        print("\nRunning product space trans-C sampler")
        print("\nNumber of walkers               : ", n_walkers)
        print("Number of states being sampled  : ", product_space.n_states)
        print("Dimensions of each state        : ", product_space.n_dims)

    pos_ps = _get_initial_product_space_positions(
        n_walkers, start_states, start_positions, product_space
    )

    log_func = partial(
        product_space_log_prob,
        product_space=product_space,
        log_posterior=log_posterior,
        log_pseudo_prior=log_pseudo_prior,
        forward_pool=forward_pool,
    )

    sampler = perform_sampling_with_emcee(
        log_prob_func=log_func,
        n_walkers=n_walkers,
        n_steps=n_steps,
        initial_state=pos_ps,
        pool=pool,
        progress=progress,
        **kwargs,
    )

    return MultiWalkerProductSpaceChain.from_emcee(sampler, product_space.n_dims)
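
The docstring requires both density callables to follow the signature `f(x, state) -> float`, and the pseudo-prior to be normalized within each state. The following is a minimal sketch of such a pair for a toy three-state problem. The names `N_DIMS`, `CENTRES` and `PSEUDO_SIGMA` and the Gaussian forms are illustrative assumptions, not part of pyTransC.

```python
import math

# Hypothetical toy problem: three states with 2, 3 and 1 parameters.
N_DIMS = [2, 3, 1]
# Illustrative centres and pseudo-prior widths (assumptions for this sketch).
CENTRES = [[0.5, 0.5], [1.0, 0.0, -1.0], [2.0]]
PSEUDO_SIGMA = [1.0, 1.0, 1.0]


def my_log_posterior(x, state):
    """Unnormalized log-posterior: isotropic Gaussian about CENTRES[state]."""
    if len(x) != N_DIMS[state]:
        raise ValueError("x has the wrong dimension for this state")
    return -0.5 * sum((xi - ci) ** 2 for xi, ci in zip(x, CENTRES[state]))


def my_log_pseudo_prior(x, state):
    """Normalized isotropic Gaussian log-density (integrates to one per state)."""
    if len(x) != N_DIMS[state]:
        raise ValueError("x has the wrong dimension for this state")
    sigma = PSEUDO_SIGMA[state]
    sq_dist = sum((xi - ci) ** 2 for xi, ci in zip(x, CENTRES[state]))
    # Normalizing constant of a d-dimensional isotropic Gaussian.
    log_norm = -0.5 * N_DIMS[state] * math.log(2.0 * math.pi * sigma**2)
    return log_norm - 0.5 * sq_dist / sigma**2
```

The posterior may be left unnormalized, whereas the pseudo-prior carries its normalizing constant explicitly, which is what the "must be normalized" note asks for.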

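
The Notes say that every state is embedded in one fixed-dimensional space. A common way to build such a target, sketched below, is to add the log-posterior of the active state to the log-pseudo-priors of all inactive states. The vector layout used here (state index first, then the parameter blocks concatenated) and the helper names are assumptions made for illustration; the actual `product_space_log_prob` in pyTransC may organize things differently.

```python
import math


def split_product_space_vector(v, n_dims):
    """Split [state, block_0, ..., block_K-1] into (state, list of blocks)."""
    state = int(round(v[0]))
    blocks, start = [], 1
    for d in n_dims:
        blocks.append(list(v[start:start + d]))
        start += d
    return state, blocks


def product_space_log_density(v, n_dims, log_posterior, log_pseudo_prior):
    """Active state's log-posterior plus inactive states' log-pseudo-priors."""
    state, blocks = split_product_space_vector(v, n_dims)
    if not 0 <= state < len(n_dims):
        return float("-inf")  # state index outside the valid range
    total = log_posterior(blocks[state], state)
    for j, block in enumerate(blocks):
        if j != state:
            total += log_pseudo_prior(block, j)
    return total


# Toy densities (assumptions): standard Gaussians in every state.
def toy_log_posterior(x, state):
    return -0.5 * sum(xi * xi for xi in x)


def toy_log_pseudo_prior(x, state):
    return -0.5 * len(x) * math.log(2.0 * math.pi) - 0.5 * sum(xi * xi for xi in x)


n_dims = [2, 3, 1]
v = [1.0] + [0.0] * sum(n_dims)  # state 1 active, all parameters at zero
value = product_space_log_density(v, n_dims, toy_log_posterior, toy_log_pseudo_prior)
```

Under this layout the product-space vector has `1 + sum(n_dims)` entries (seven here), which illustrates why the sampler works in a higher-dimensional space than any single state.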
def run_ensemble_resampler( n_walkers, n_steps, n_states: int, n_dims: list[int], log_posterior_ens: list[numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]], log_pseudo_prior_ens: list[numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]], seed=61254557, state_proposal_weights: list[list[float]] | None = None, progress=False, walker_pool=None, state_pool=None, forward_pool=None) -> pytransc.samplers.ensemble_resampler.MultiWalkerEnsembleResamplerChain:
def run_ensemble_resampler(  # Independent state Marginal Likelihoods from pre-computed posterior and pseudo prior ensembles
    n_walkers,
    n_steps,
    n_states: int,
    n_dims: list[int],
    log_posterior_ens: StateOrderedEnsemble,
    log_pseudo_prior_ens: StateOrderedEnsemble,
    seed=61254557,
    state_proposal_weights: list[list[float]] | None = None,
    progress=False,
    walker_pool=None,
    state_pool=None,
    forward_pool=None,
) -> MultiWalkerEnsembleResamplerChain:
    """Run MCMC sampler over independent states using pre-computed ensembles.

    This function performs trans-conceptual MCMC by resampling from pre-computed
    posterior ensembles in each state. It calculates relative evidence of each state
    by sampling over the ensemble members according to their posterior and pseudo-prior
    densities.

    Parameters
    ----------
    n_walkers : int
        Number of random walkers used by the ensemble resampler.
    n_steps : int
        Number of Markov chain steps to perform per walker.
    n_states : int
        Number of independent states in the problem.
    n_dims : list of int
        List of parameter dimensions for each state.
    log_posterior_ens : StateOrderedEnsemble
        Log-posterior values of ensemble members in each state.
        Format: list of arrays, where each array contains log-posterior values
        for the ensemble members in that state.
    log_pseudo_prior_ens : StateOrderedEnsemble
        Log-pseudo-prior values of ensemble members in each state.
        Format: list of arrays, where each array contains log-pseudo-prior values
        for the ensemble members in that state.
    seed : int, optional
        Random number seed for reproducible results. Default is 61254557.
    state_proposal_weights : list of list of float, optional
        Weights for proposing transitions between states. Should be a matrix
        where element [i][j] is the weight for proposing state j from state i.
        Diagonal elements are ignored. If None, uniform weights are used.
    progress : bool, optional
        Whether to display progress information. Default is False.
    walker_pool : Any | None, optional
        User-provided pool for parallelizing walker execution. The pool must
        implement a map() method compatible with the standard library's map()
        function. Default is None.
    state_pool : Any | None, optional
        User-provided pool for parallelizing state-level operations such as
        pseudo-prior evaluation across states. Currently reserved for future
        enhancements. Default is None.
    forward_pool : Any | None, optional
        User-provided pool for parallelizing forward solver calls within
        log_posterior evaluations. If provided, the pool will be made available
        to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
        The pool must implement a map() method compatible with the standard library's
        map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
        and schwimmbad pools. Default is None.

    Returns
    -------
    MultiWalkerEnsembleResamplerChain
        Chain results containing state sequences, ensemble member indices,
        and diagnostics for all walkers.

    Notes
    -----
    This method requires pre-computed posterior ensembles and their corresponding
    log-density values. The ensembles can be generated using `run_mcmc_per_state()`
    and the pseudo-prior values using automatic fitting routines.

    The algorithm works by:
    1. Selecting ensemble members within states based on posterior weights
    2. Proposing transitions between states based on relative evidence
    3. Accepting/rejecting proposals using Metropolis-Hastings criterion

    Examples
    --------
    >>> results = run_ensemble_resampler(
    ...     n_walkers=32,
    ...     n_steps=1000,
    ...     n_states=3,
    ...     n_dims=[2, 3, 1],
    ...     log_posterior_ens=posterior_ensembles,
    ...     log_pseudo_prior_ens=pseudo_prior_ensembles
    ... )

    Using with forward pool for parallel forward solver calls:

    >>> from concurrent.futures import ProcessPoolExecutor
    >>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
    ...     results = run_ensemble_resampler(
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         n_states=3,
    ...         n_dims=[2, 3, 1],
    ...         log_posterior_ens=posterior_ensembles,
    ...         log_pseudo_prior_ens=pseudo_prior_ensembles,
    ...         forward_pool=forward_pool
    ...     )
    """

    n_samples = [len(log_post_ens) for log_post_ens in log_posterior_ens]

    # Early validation of forward pool if provided
    if forward_pool is not None:
        from ..utils.forward_context import set_forward_pool, clear_forward_pool
        set_forward_pool(forward_pool)  # Validates map() method
        clear_forward_pool()  # Clear after validation

    if state_proposal_weights is None:
        # uniform proposal weights
        _state_proposal_weights = [[1.0] * n_states] * n_states
    else:
        _state_proposal_weights = np.array(state_proposal_weights)
        np.fill_diagonal(_state_proposal_weights, 0.0)  # ensure diagonal is zero
        _state_proposal_weights = _state_proposal_weights / _state_proposal_weights.sum(
            axis=1, keepdims=True
        )  # set row sums to unity
        _state_proposal_weights = _state_proposal_weights.tolist()

    logger.info("\nRunning ensemble resampler")
    logger.info("\nNumber of walkers               : %d", n_walkers)
    logger.info("Number of states being sampled  : %d", n_states)
    logger.info("Dimensions of each state        : %s", n_dims)

    random.seed(seed)
    if walker_pool is not None:
        chains = _run_mcmc_walker_parallel(
            n_walkers,
            n_states,
            n_samples,
            n_steps,
            log_posterior_ens,
            log_pseudo_prior_ens,
            state_proposal_weights=_state_proposal_weights,
            progress=progress,
            walker_pool=walker_pool,
            forward_pool=forward_pool,
        )

    else:
        chains = _run_mcmc_walker_serial(
            n_walkers,
            n_states,
            n_samples,
            n_steps,
            log_posterior_ens,
            log_pseudo_prior_ens,
            state_proposal_weights=_state_proposal_weights,
            progress=progress,
            forward_pool=forward_pool,
        )

    return MultiWalkerEnsembleResamplerChain(chains)
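
When `state_proposal_weights` is supplied, the source above zeroes the diagonal and rescales each row to sum to one. The sketch below mirrors that step in plain Python so the effect is easy to inspect. The function name and the guard against rows with no off-diagonal weight are additions for illustration and are not present in the library code shown above.

```python
def normalise_state_proposal_weights(weights):
    """Zero the diagonal of a square weight matrix and normalise each row."""
    n = len(weights)
    normalised = []
    for i, row in enumerate(weights):
        if len(row) != n:
            raise ValueError("state_proposal_weights must be a square matrix")
        # Diagonal entries are ignored: a walker never proposes its own state.
        off_diag = [0.0 if j == i else float(w) for j, w in enumerate(row)]
        total = sum(off_diag)
        if total <= 0.0:
            # Guard added in this sketch only; avoids division by zero.
            raise ValueError(f"row {i} has no positive off-diagonal weight")
        normalised.append([w / total for w in off_diag])
    return normalised


weights = [
    [5.0, 1.0, 3.0],
    [2.0, 9.0, 2.0],
    [1.0, 1.0, 1.0],
]
proposal = normalise_state_proposal_weights(weights)
# proposal[0] is [0.0, 0.25, 0.75]: from state 0, state 2 is proposed
# three times as often as state 1.
```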

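
The docstring states that the resampler estimates the relative evidence of each state. In trans-dimensional samplers of this kind, a common estimate is the fraction of chain steps spent in each state, which approximates the posterior state probability and, under equal prior state probabilities, the relative evidence. The sketch below assumes the state sequences have already been extracted as plain lists of integers; how to obtain them from `MultiWalkerEnsembleResamplerChain` is not shown here, and the function name is hypothetical.

```python
from collections import Counter


def state_visit_fractions(state_chains, n_states, discard=0):
    """Fraction of post-burn-in steps spent in each state, pooled over walkers."""
    counts = Counter()
    for chain in state_chains:
        counts.update(chain[discard:])  # drop the first `discard` steps per walker
    total = sum(counts.values())
    if total == 0:
        raise ValueError("no samples left after discarding burn-in")
    return [counts[k] / total for k in range(n_states)]


# Two hypothetical walkers, four steps each, three states.
chains = [[0, 1, 1, 2], [1, 1, 0, 1]]
fractions = state_visit_fractions(chains, n_states=3)
```

States that are never visited get a fraction of zero, which usually signals that more steps or better pseudo-priors are needed rather than a true zero evidence.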
def run_mcmc_per_state( n_states: int, n_dims: list[int], n_walkers: int | list[int], n_steps: int | list[int], pos: list[numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]], log_posterior: pytransc.utils.types.MultiStateDensity, discard: int | list[int] = 0, thin: int | list[int] = 1, auto_thin: bool = False, seed: int = 61254557, state_pool: typing.Any | None = None, emcee_pool: typing.Any | None = None, n_state_processors: int | None = None, skip_initial_state_check: bool = False, verbose: bool = True, forward_pool: typing.Any | None = None, **kwargs) -> tuple[list[numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]], list[numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]]]:
def run_mcmc_per_state(
    n_states: int,
    n_dims: list[int],
    n_walkers: int | list[int],
    n_steps: int | list[int],
    pos: list[FloatArray],
    log_posterior: MultiStateDensity,
    discard: int | list[int] = 0,
    thin: int | list[int] = 1,
    auto_thin: bool = False,
    seed: int = 61254557,
    state_pool: Any | None = None,
    emcee_pool: Any | None = None,
    n_state_processors: int | None = None,
    skip_initial_state_check: bool = False,
    verbose: bool = True,
    forward_pool: Any | None = None,
    **kwargs,
) -> tuple[list[FloatArray], list[FloatArray]]:
    """Run independent MCMC sampling within each state.

    This utility function runs the emcee sampler independently within each state
    to generate posterior ensembles. These ensembles can then be used as input
    for ensemble resampling or for constructing pseudo-priors.

    Parameters
    ----------
    n_states : int
        Number of independent states in the problem.
    n_dims : list of int
        List of parameter dimensions for each state.
    n_walkers : int or list of int
        Number of random walkers for the emcee sampler. If int, same number
        is used for all states. If list, specifies walkers per state.
    n_steps : int or list of int
        Number of MCMC steps per walker. If int, same number is used for all
        states. If list, specifies steps per state.
    pos : list of FloatArray
        Starting positions for each state. Each array should have shape
        (n_walkers[state], n_dims[state]).
    log_posterior : MultiStateDensity
        Function to evaluate the log-posterior density at location x in state i.
        Must have signature log_posterior(x, state) -> float.
    discard : int or list of int, optional
        Number of samples to discard as burn-in. If int, same value used for
        all states. Default is 0.
    thin : int or list of int, optional
        Thinning factor for chains. If int, same value used for all states.
        Default is 1 (no thinning).
    auto_thin : bool, optional
        If True, automatically thin chains based on autocorrelation time,
        ignoring the `thin` parameter. Default is False.
    seed : int, optional
        Random number seed for reproducible results. Default is 61254557.
    skip_initial_state_check : bool, optional
        Whether to skip emcee's initial state check. Default is False.
    verbose : bool, optional
        Whether to print progress information. Default is True.
    state_pool : Any | None, optional
        User-provided pool for parallelizing state execution. If provided, states
        are processed in parallel. The pool must implement a map() method compatible
        with the standard library's map() function. Default is None.
    emcee_pool : Any | None, optional
        User-provided pool for parallelizing emcee walker execution within each state.
        The pool must implement a map() method compatible with the standard library's
        map() function. Default is None.
    forward_pool : Any | None, optional
        User-provided pool for parallelizing forward solver calls within
        log_posterior evaluations. If provided, the pool will be made available
        to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
        The pool must implement a map() method compatible with the standard library's
        map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
        and schwimmbad pools. Default is None.
    **kwargs
        Additional keyword arguments passed to emcee.EnsembleSampler.

    Returns
    -------
    ensemble_per_state : list of FloatArray
        Posterior samples for each state. Each array has shape
        (n_samples, n_dims[state]).
    log_posterior_ens : list of FloatArray
        Log posterior values for each ensemble. Each array has shape (n_samples,).

    Notes
    -----
    This function is primarily a convenience wrapper around emcee for generating
    posterior ensembles within each state independently. The resulting ensembles
    can be used with:

    - `run_ensemble_resampler()` for ensemble-based trans-dimensional sampling
    - Automatic pseudo-prior construction functions
    - Direct analysis of within-state posterior distributions

    If `auto_thin=True`, the function will automatically determine appropriate
    burn-in and thinning based on the autocorrelation time, following emcee
    best practices.

    Examples
    --------
    Basic usage:

    >>> ensembles, log_probs = run_mcmc_per_state(
    ...     n_states=2,
    ...     n_dims=[3, 2],
    ...     n_walkers=32,
    ...     n_steps=1000,
    ...     pos=[np.random.randn(32, 3), np.random.randn(32, 2)],
    ...     log_posterior=my_log_posterior,
    ...     auto_thin=True
    ... )

    Using with state-level parallelism:

    >>> from concurrent.futures import ProcessPoolExecutor
    >>> with ProcessPoolExecutor(max_workers=4) as state_pool:
    ...     ensembles, log_probs = run_mcmc_per_state(
    ...         n_states=4,
    ...         n_dims=[3, 2, 4, 1],
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         pos=initial_positions,
    ...         log_posterior=my_log_posterior,
    ...         state_pool=state_pool
    ...     )

    Using with both state and emcee parallelism:

    >>> from schwimmbad import MPIPool
    >>> with MPIPool() as state_pool, ProcessPoolExecutor() as emcee_pool:
    ...     ensembles, log_probs = run_mcmc_per_state(
    ...         n_states=2,
    ...         n_dims=[3, 2],
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         pos=initial_positions,
    ...         log_posterior=my_log_posterior,
    ...         state_pool=state_pool,
    ...         emcee_pool=emcee_pool
    ...     )

    Using with forward pool for parallel forward solver calls:

    >>> from concurrent.futures import ProcessPoolExecutor
    >>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
    ...     ensembles, log_probs = run_mcmc_per_state(
    ...         n_states=3,
    ...         n_dims=[2, 3, 1],
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         pos=initial_positions,
    ...         log_posterior=my_log_posterior,
    ...         forward_pool=forward_pool
    ...     )
    """
253
    random.seed(seed)

    # Early validation of forward pool if provided
    if forward_pool is not None:
        from ..utils.forward_context import set_forward_pool, clear_forward_pool
        set_forward_pool(forward_pool)  # Validates map() method
        clear_forward_pool()  # Clear after validation

    if not isinstance(n_walkers, list):
        n_walkers = [n_walkers] * n_states
    if not isinstance(discard, list):
        discard = [discard] * n_states
    if not isinstance(thin, list):
        thin = [thin] * n_states
    if not isinstance(n_steps, list):
        n_steps = [n_steps] * n_states

    if auto_thin:
        # ignore thinning factor because we are post thinning by the auto-correlation times
        # burn_in is also calculated from the auto-correlation time
        thin = [1] * n_states
        discard = [0] * n_states

    if verbose:
        print("\nRunning within-state sampler separately on each state")
        print("\nNumber of walkers               : ", n_walkers)
        print("\nNumber of states being sampled: ", n_states)
        print("Dimensions of each state: ", n_dims)
        if state_pool is not None:
            print("Using state-level parallelism")
        if emcee_pool is not None:
            print("Using walker-level parallelism")

    # Prepare emcee pool configuration to avoid pickling issues
    emcee_pool_config = None
    if emcee_pool is not None:
        # Determine pool type and configuration
        if hasattr(emcee_pool, '__class__'):
            pool_class_name = emcee_pool.__class__.__name__
            if pool_class_name == 'ProcessPoolExecutor':
                emcee_pool_config = {
                    'type': 'ProcessPoolExecutor',
                    'kwargs': {'max_workers': emcee_pool._max_workers}
                }
            elif pool_class_name == 'ThreadPoolExecutor':
                emcee_pool_config = {
                    'type': 'ThreadPoolExecutor',
                    'kwargs': {'max_workers': emcee_pool._max_workers}
                }

    # Prepare forward pool configuration to avoid pickling issues
    forward_pool_config = None
    if forward_pool is not None and state_pool is not None:
        # Only create config when using state-level parallelism
        if hasattr(forward_pool, '__class__'):
            pool_class_name = forward_pool.__class__.__name__
            if pool_class_name == 'ProcessPoolExecutor':
                forward_pool_config = {
                    'type': 'ProcessPoolExecutor',
                    'kwargs': {'max_workers': forward_pool._max_workers}
                }
            elif pool_class_name == 'ThreadPoolExecutor':
                forward_pool_config = {
                    'type': 'ThreadPoolExecutor',
                    'kwargs': {'max_workers': forward_pool._max_workers}
                }

    # Prepare state processing arguments
    state_args = []
    for i in range(n_states):
        args_dict = {
            'state_idx': i,
            'log_posterior': log_posterior,
            'n_walkers': n_walkers[i],
            'pos': pos[i],
            'n_steps': n_steps[i],
            'discard': discard[i],
            'thin': thin[i],
            'skip_initial_state_check': skip_initial_state_check,
            'verbose': verbose,
            **kwargs
        }

        # Add pool configs for state-level parallelism
        if state_pool is not None:
            args_dict['emcee_pool_config'] = emcee_pool_config
            if forward_pool_config is not None:
                args_dict['forward_pool_config'] = forward_pool_config
            # Don't pass the actual pool objects
        else:
            # For sequential processing, pass the pools directly
            args_dict['emcee_pool'] = emcee_pool
            args_dict['forward_pool'] = forward_pool

        state_args.append(args_dict)

    # Process states in parallel or sequentially
    if state_pool is not None:
        # Use provided state pool for parallel processing
        results = list(state_pool.map(_process_single_state, state_args))
    elif n_state_processors is not None and n_state_processors > 1:
        # Create internal ProcessPoolExecutor for state parallelism
        with ProcessPoolExecutor(max_workers=n_state_processors) as executor:
            results = list(executor.map(_process_single_state, state_args))
    else:
        # Sequential processing (original behavior)
        # For sequential, we can pass pools directly since no pickling needed
        for args in state_args:
            if 'emcee_pool_config' in args:
                del args['emcee_pool_config']
            if 'forward_pool_config' in args:
                del args['forward_pool_config']
            args['emcee_pool'] = emcee_pool
            args['forward_pool'] = forward_pool
        results = [_process_single_state(args) for args in state_args]

    # Unpack results
    samples = [result[0] for result in results]
    log_posterior_ens = [result[1] for result in results]
    auto_correlation = [result[2] for result in results]

    if auto_thin:
        samples, log_posterior_ens = _perform_auto_thinning(
            samples, log_posterior_ens, auto_correlation, verbose=verbose
        )

    return samples, log_posterior_ens

Run independent MCMC sampling within each state.

This utility function runs the emcee sampler independently within each state to generate posterior ensembles. These ensembles can then be used as input for ensemble resampling or for constructing pseudo-priors.

Parameters

n_states : int
    Number of independent states in the problem.
n_dims : list of int
    List of parameter dimensions for each state.
n_walkers : int or list of int
    Number of random walkers for the emcee sampler. If int, the same number is used for all states. If list, specifies walkers per state.
n_steps : int or list of int
    Number of MCMC steps per walker. If int, the same number is used for all states. If list, specifies steps per state.
pos : list of FloatArray
    Starting positions for each state. Each array should have shape (n_walkers[state], n_dims[state]).
log_posterior : MultiStateDensity
    Function to evaluate the log-posterior density at location x in state i. Must have signature log_posterior(x, state) -> float.
discard : int or list of int, optional
    Number of samples to discard as burn-in. If int, the same value is used for all states. Default is 0.
thin : int or list of int, optional
    Thinning factor for chains. If int, the same value is used for all states. Default is 1 (no thinning).
auto_thin : bool, optional
    If True, automatically thin chains based on autocorrelation time, ignoring the thin and discard parameters. Default is False.
seed : int, optional
    Random number seed for reproducible results. Default is 61254557.
skip_initial_state_check : bool, optional
    Whether to skip emcee's initial state check. Default is False.
verbose : bool, optional
    Whether to print progress information. Default is True.
state_pool : Any | None, optional
    User-provided pool for parallelizing state execution. If provided, states are processed in parallel. The pool must implement a map() method compatible with the standard library's map() function. Default is None.
emcee_pool : Any | None, optional
    User-provided pool for parallelizing emcee walker execution within each state. The pool must implement a map() method compatible with the standard library's map() function. Default is None.
forward_pool : Any | None, optional
    User-provided pool for parallelizing forward solver calls within log_posterior evaluations. If provided, the pool will be made available to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context. The pool must implement a map() method compatible with the standard library's map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor, and schwimmbad pools. Default is None.
**kwargs
    Additional keyword arguments passed to emcee.EnsembleSampler.
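
To make the `log_posterior(x, state)` signature and the expected shapes of `pos` concrete, here is a minimal sketch. The Gaussian toy target, the body of `my_log_posterior`, and the values of `n_dims` and `n_walkers` are assumptions chosen for illustration; they are not part of pyTransC.

```python
import numpy as np

n_dims = [3, 2]   # illustrative parameter count for each of two states
n_walkers = 32    # illustrative walker count shared by both states


def my_log_posterior(x, state):
    """Toy target: unnormalized standard Gaussian log-density in each state."""
    x = np.asarray(x, dtype=float)
    if x.shape != (n_dims[state],):
        raise ValueError(f"state {state} expects {n_dims[state]} parameters")
    return float(-0.5 * np.sum(x**2))


# One starting array per state, each of shape (n_walkers, n_dims[state]).
rng = np.random.default_rng(0)
initial_positions = [rng.normal(size=(n_walkers, d)) for d in n_dims]
```

Objects like these could then be passed as `log_posterior=my_log_posterior` and `pos=initial_positions`.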

Returns


ensemble_per_state : list of FloatArray
    Posterior samples for each state. Each array has shape (n_samples, n_dims[state]).
log_posterior_ens : list of FloatArray
    Log posterior values for each ensemble. Each array has shape (n_samples,).

Notes

This function is primarily a convenience wrapper around emcee for generating posterior ensembles within each state independently. The resulting ensembles can be used with:

  • run_ensemble_resampler() for ensemble-based trans-dimensional sampling
  • Automatic pseudo-prior construction functions
  • Direct analysis of within-state posterior distributions

If auto_thin=True, the function will automatically determine appropriate burn-in and thinning based on the autocorrelation time, following emcee best practices.
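
As a rough illustration of autocorrelation-based post-processing, the sketch below applies the rule of thumb from the emcee tutorials: discard about twice the largest autocorrelation time as burn-in and thin by about half the smallest. The helper name `thin_by_autocorr` and the synthetic chain are hypothetical, and the exact rule used internally by pyTransC may differ.

```python
import numpy as np


def thin_by_autocorr(chain, log_prob, tau):
    """Discard burn-in and thin a chain using autocorrelation times.

    chain    : array of shape (n_steps, n_walkers, n_dim)
    log_prob : array of shape (n_steps, n_walkers)
    tau      : array of shape (n_dim,), autocorrelation time per parameter
    """
    burn_in = int(2 * np.max(tau))          # burn-in of roughly 2 * max(tau)
    thin = max(1, int(0.5 * np.min(tau)))   # keep roughly every min(tau) / 2 steps
    kept = chain[burn_in::thin]
    kept_lp = log_prob[burn_in::thin]
    # Flatten walkers into a single ensemble of samples.
    return kept.reshape(-1, chain.shape[-1]), kept_lp.reshape(-1)


# Synthetic stand-in for an emcee chain (assumed values, not real output).
rng = np.random.default_rng(1)
chain = rng.normal(size=(1000, 32, 3))
log_prob = -0.5 * np.sum(chain**2, axis=-1)
samples, lp = thin_by_autocorr(chain, log_prob, tau=np.array([20.0, 25.0, 30.0]))
```

In a real run, `tau` would typically come from `emcee.EnsembleSampler.get_autocorr_time()`.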

Examples

Basic usage:

>>> ensembles, log_probs = run_mcmc_per_state(
...     n_states=2,
...     n_dims=[3, 2],
...     n_walkers=32,
...     n_steps=1000,
...     pos=[np.random.randn(32, 3), np.random.randn(32, 2)],
...     log_posterior=my_log_posterior,
...     auto_thin=True
... )

Using state-level parallelism:

>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as state_pool:
...     ensembles, log_probs = run_mcmc_per_state(
...         n_states=4,
...         n_dims=[3, 2, 4, 1],
...         n_walkers=32,
...         n_steps=1000,
...         pos=initial_positions,
...         log_posterior=my_log_posterior,
...         state_pool=state_pool
...     )

Using both state-level and emcee (walker-level) parallelism:

>>> from schwimmbad import MPIPool
>>> with MPIPool() as state_pool, ProcessPoolExecutor() as emcee_pool:
...     ensembles, log_probs = run_mcmc_per_state(
...         n_states=2,
...         n_dims=[3, 2],
...         n_walkers=32,
...         n_steps=1000,
...         pos=initial_positions,
...         log_posterior=my_log_posterior,
...         state_pool=state_pool,
...         emcee_pool=emcee_pool
...     )

Using a forward pool for parallel forward solver calls:

>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
...     ensembles, log_probs = run_mcmc_per_state(
...         n_states=3,
...         n_dims=[2, 3, 1],
...         n_walkers=32,
...         n_steps=1000,
...         pos=initial_positions,
...         log_posterior=my_log_posterior,
...         forward_pool=forward_pool
...     )
def run_state_jump_sampler( n_walkers, n_steps, n_states: int, n_dims: list[int], start_positions: list[numpy.ndarray[tuple[typing.Any, ...], numpy.dtype[numpy.floating]]], start_states: list[int], log_posterior: pytransc.utils.types.MultiStateDensity, log_pseudo_prior: pytransc.utils.types.SampleableMultiStateDensity, log_proposal: pytransc.utils.types.ProposableMultiStateDensity, prob_state=0.1, seed=61254557, progress=False, walker_pool=None, forward_pool=None) -> pytransc.samplers.state_jump.MultiWalkerStateJumpChain:
def run_state_jump_sampler(  # Independent state MCMC sampler on product space with proposal equal to pseudo prior
    n_walkers,
    n_steps,
    n_states: int,
    n_dims: list[int],
    start_positions: list[FloatArray],
    start_states: list[int],
    log_posterior: MultiStateDensity,
    log_pseudo_prior: SampleableMultiStateDensity,
    log_proposal: ProposableMultiStateDensity,
    prob_state=0.1,
    seed=61254557,
    progress=False,
    walker_pool=None,
    forward_pool=None,
) -> MultiWalkerStateJumpChain:
    """Run MCMC sampler with direct jumps between states of different dimensions.

    This function implements trans-conceptual MCMC using a Metropolis-Hastings
    algorithm that can propose jumps between states with different numbers of
    parameters. Between-state moves use the pseudo-prior as the proposal, while
    within-state moves use a user-defined proposal function.

    Parameters
    ----------
    n_walkers : int
        Number of random walkers used by the state jump sampler.
    n_steps : int
        Number of MCMC steps required per walker.
    n_states : int
        Number of independent states in the problem.
    n_dims : list of int
        List of parameter dimensions for each state.
    start_positions : list of FloatArray
        Starting parameter positions for each walker. Each array should contain
        the initial parameter values for the corresponding starting state.
    start_states : list of int
        Starting state indices for each walker.
    log_posterior : MultiStateDensity
        Function to evaluate the log-posterior density at location x in state i.
        Must have signature log_posterior(x, state) -> float.
    log_pseudo_prior : SampleableMultiStateDensity
        Object with methods:
        - __call__(x, state) -> float: evaluate log pseudo-prior at x for state
        - draw_deviate(state) -> FloatArray: sample from pseudo-prior for state
        Note: Must be normalized over respective state spaces.
    log_proposal : ProposableMultiStateDensity
        Object with methods:
        - propose(x_current, state) -> FloatArray: propose new x in state
        - __call__(x, state) -> float: log proposal probability (for MH ratio)
    prob_state : float, optional
        Probability of proposing a state change per MCMC step. Otherwise,
        a parameter change within the current state is proposed. Default is 0.1.
    seed : int, optional
        Random number seed for reproducible results. Default is 61254557.
    progress : bool, optional
        Whether to display progress information. Default is False.
    walker_pool : Any | None, optional
        User-provided pool for parallelizing walker execution. The pool must
        implement a map() method compatible with the standard library's map()
        function. Default is None.
    forward_pool : Any | None, optional
        User-provided pool for parallelizing forward solver calls within
        log_posterior evaluations. If provided, the pool will be made available
        to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context.
        The pool must implement a map() method compatible with the standard library's
        map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor,
        and schwimmbad pools. Default is None.

    Returns
    -------
    MultiWalkerStateJumpChain
        Chain results containing state sequences, model parameters, proposal
        acceptance rates, and diagnostics for all walkers.

    Notes
    -----
    The algorithm uses a Metropolis-Hastings sampler with two types of moves:

    1. **Between-state moves** (probability `prob_state`):
       - Propose a new state uniformly at random
       - Generate new parameters from the pseudo-prior of the proposed state
       - Accept/reject based on posterior and pseudo-prior ratios

    2. **Within-state moves** (probability `1 - prob_state`):
       - Use the user-defined proposal function to generate new parameters
       - Accept/reject using standard Metropolis-Hastings criterion

    The pseudo-prior must be normalized for the between-state acceptance
    criterion to satisfy detailed balance.

    Examples
    --------
    Basic usage:

    >>> results = run_state_jump_sampler(
    ...     n_walkers=3,
    ...     n_steps=1000,
    ...     n_states=3,
    ...     n_dims=[2, 3, 1],
    ...     start_positions=[[0.5, 0.5], [1.0, 0.0, -1.0], [2.0]],
    ...     start_states=[0, 1, 2],
    ...     log_posterior=my_log_posterior,
    ...     log_pseudo_prior=my_log_pseudo_prior,
    ...     log_proposal=my_log_proposal,
    ...     prob_state=0.2
    ... )

    Using a user-provided walker pool:

    >>> from concurrent.futures import ProcessPoolExecutor
    >>> with ProcessPoolExecutor(max_workers=4) as walker_pool:
    ...     results = run_state_jump_sampler(
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         n_states=3,
    ...         n_dims=[2, 3, 1],
    ...         start_positions=start_positions,
    ...         start_states=start_states,
    ...         log_posterior=my_log_posterior,
    ...         log_pseudo_prior=my_log_pseudo_prior,
    ...         log_proposal=my_log_proposal,
    ...         walker_pool=walker_pool
    ...     )

    Using a forward pool for parallel forward solver calls:

    >>> from concurrent.futures import ProcessPoolExecutor
    >>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
    ...     results = run_state_jump_sampler(
    ...         n_walkers=32,
    ...         n_steps=1000,
    ...         n_states=3,
    ...         n_dims=[2, 3, 1],
    ...         start_positions=start_positions,
    ...         start_states=start_states,
    ...         log_posterior=my_log_posterior,
    ...         log_pseudo_prior=my_log_pseudo_prior,
    ...         log_proposal=my_log_proposal,
    ...         forward_pool=forward_pool
    ...     )
    """

    logger.info("Running state-jump trans-C sampler")
    logger.info("Number of walkers: %d", n_walkers)
    logger.info("Number of states being sampled: %d", n_states)
    logger.info("Dimensions of each state: %s", n_dims)

    # Early validation of forward pool if provided
    if forward_pool is not None:
        from ..utils.forward_context import set_forward_pool, clear_forward_pool
        set_forward_pool(forward_pool)  # Validates map() method
        clear_forward_pool()  # Clear after validation

    random.seed(seed)

    if walker_pool is not None:  # put random walkers on different processors
        chains = _run_state_jump_sampler_parallel(
            n_walkers,
            n_steps,
            n_states,
            start_positions,
            start_states,
            log_posterior,
            log_pseudo_prior,
            log_proposal,
            prob_state=prob_state,
            progress=progress,
            walker_pool=walker_pool,
            forward_pool=forward_pool,
        )
    else:
        chains = _run_state_jump_sampler_serial(
            n_walkers,
            n_steps,
            n_states,
            start_positions,
            start_states,
            log_posterior,
            log_pseudo_prior,
            log_proposal,
            prob_state=prob_state,
            progress=progress,
            forward_pool=forward_pool,
        )
    return MultiWalkerStateJumpChain(chains)

Run MCMC sampler with direct jumps between states of different dimensions.

This function implements trans-conceptual MCMC using a Metropolis-Hastings algorithm that can propose jumps between states with different numbers of parameters. Between-state moves use the pseudo-prior as the proposal, while within-state moves use a user-defined proposal function.

Parameters

n_walkers : int
    Number of random walkers used by the state jump sampler.
n_steps : int
    Number of MCMC steps required per walker.
n_states : int
    Number of independent states in the problem.
n_dims : list of int
    List of parameter dimensions for each state.
start_positions : list of FloatArray
    Starting parameter positions for each walker. Each array should contain the initial parameter values for the corresponding starting state.
start_states : list of int
    Starting state indices for each walker.
log_posterior : MultiStateDensity
    Function to evaluate the log-posterior density at location x in state i. Must have signature log_posterior(x, state) -> float.
log_pseudo_prior : SampleableMultiStateDensity
    Object with methods:
    - __call__(x, state) -> float: evaluate log pseudo-prior at x for state
    - draw_deviate(state) -> FloatArray: sample from pseudo-prior for state
    Note: Must be normalized over respective state spaces.
log_proposal : ProposableMultiStateDensity
    Object with methods:
    - propose(x_current, state) -> FloatArray: propose new x in state
    - __call__(x, state) -> float: log proposal probability (for MH ratio)
prob_state : float, optional
    Probability of proposing a state change per MCMC step. Otherwise, a parameter change within the current state is proposed. Default is 0.1.
seed : int, optional
    Random number seed for reproducible results. Default is 61254557.
progress : bool, optional
    Whether to display progress information. Default is False.
walker_pool : Any | None, optional
    User-provided pool for parallelizing walker execution. The pool must implement a map() method compatible with the standard library's map() function. Default is None.
forward_pool : Any | None, optional
    User-provided pool for parallelizing forward solver calls within log_posterior evaluations. If provided, the pool will be made available to log_posterior functions via get_forward_pool() from pytransc.utils.forward_context. The pool must implement a map() method compatible with the standard library's map() function. Supports ProcessPoolExecutor, ThreadPoolExecutor, and schwimmbad pools. Default is None.
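
The two object interfaces described above can be sketched as small classes. The class names `GaussianPseudoPrior` and `RandomWalkProposal`, the isotropic Gaussian forms, and the choice to return 0.0 for a symmetric proposal (on the assumption that a constant cancels in the Metropolis-Hastings ratio) are all illustrative; how pyTransC actually consumes the proposal's `__call__` value is not shown in this documentation.

```python
import numpy as np


class GaussianPseudoPrior:
    """Normalized isotropic Gaussian per state (hypothetical example)."""

    def __init__(self, means, sigmas, seed=0):
        self.means = [np.asarray(m, dtype=float) for m in means]
        self.sigmas = [float(s) for s in sigmas]
        self.rng = np.random.default_rng(seed)

    def __call__(self, x, state):
        # Log-density of N(mean, sigma^2 I), including the normalizing constant.
        mu, s = self.means[state], self.sigmas[state]
        d = mu.size
        r2 = np.sum((np.asarray(x, dtype=float) - mu) ** 2)
        return float(-0.5 * r2 / s**2 - d * np.log(s) - 0.5 * d * np.log(2 * np.pi))

    def draw_deviate(self, state):
        mu, s = self.means[state], self.sigmas[state]
        return mu + s * self.rng.normal(size=mu.size)


class RandomWalkProposal:
    """Symmetric Gaussian random-walk proposal (hypothetical example)."""

    def __init__(self, step_sizes, seed=1):
        self.step_sizes = [float(s) for s in step_sizes]
        self.rng = np.random.default_rng(seed)

    def propose(self, x_current, state):
        x = np.asarray(x_current, dtype=float)
        return x + self.step_sizes[state] * self.rng.normal(size=x.size)

    def __call__(self, x, state):
        # Assumption: a symmetric proposal contributes a constant that cancels.
        return 0.0


pseudo_prior = GaussianPseudoPrior(
    means=[[0.0, 0.0], [0.0, 0.0, 0.0], [0.0]], sigmas=[1.0, 1.0, 1.0]
)
proposal = RandomWalkProposal(step_sizes=[0.1, 0.1, 0.1])
```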

Returns

MultiWalkerStateJumpChain
    Chain results containing state sequences, model parameters, proposal acceptance rates, and diagnostics for all walkers.

Notes

The algorithm uses a Metropolis-Hastings sampler with two types of moves:

  1. Between-state moves (probability prob_state):

    • Propose a new state uniformly at random
    • Generate new parameters from the pseudo-prior of the proposed state
    • Accept/reject based on posterior and pseudo-prior ratios
  2. Within-state moves (probability 1 - prob_state):

    • Use the user-defined proposal function to generate new parameters
    • Accept/reject using standard Metropolis-Hastings criterion

The pseudo-prior must be normalized for the between-state acceptance criterion to satisfy detailed balance.
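
A between-state move of this kind can be sketched as an independence-style Metropolis-Hastings step. With a symmetric uniform state proposal and new parameters drawn from the pseudo-prior, the log acceptance ratio is log p(x', k') - log p(x, k) + log q(x, k) - log q(x', k'), where p is the posterior and q the pseudo-prior. The helper names, the two-state toy posterior, and the choice to draw the proposed state from all states (including the current one) are assumptions made for illustration, not the pyTransC implementation.

```python
import numpy as np

LOG_WEIGHTS = np.log([0.25, 0.75])  # toy relative evidence of two states
N_DIMS = [1, 2]                     # toy parameter count per state


def log_normal(x):
    """Log-density of a standard normal in x.size dimensions."""
    x = np.asarray(x, dtype=float)
    return float(-0.5 * np.sum(x**2) - 0.5 * x.size * np.log(2 * np.pi))


def log_posterior(x, state):
    return float(LOG_WEIGHTS[state]) + log_normal(x)


class StandardNormalPseudoPrior:
    """Normalized pseudo-prior with the documented call/draw_deviate methods."""

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, x, state):
        return log_normal(x)

    def draw_deviate(self, state):
        return self.rng.normal(size=N_DIMS[state])


def between_state_log_alpha(x, k, x_new, k_new, log_posterior, log_pseudo_prior):
    # Posterior ratio times the reverse/forward pseudo-prior proposal ratio.
    return (log_posterior(x_new, k_new) - log_posterior(x, k)
            + log_pseudo_prior(x, k) - log_pseudo_prior(x_new, k_new))


def between_state_move(x, k, n_states, log_posterior, log_pseudo_prior, rng):
    k_new = int(rng.integers(n_states))           # uniform state proposal
    x_new = log_pseudo_prior.draw_deviate(k_new)  # parameters from pseudo-prior
    log_alpha = between_state_log_alpha(
        x, k, x_new, k_new, log_posterior, log_pseudo_prior
    )
    if np.log(rng.uniform()) < log_alpha:
        return x_new, k_new                       # accept the jump
    return x, k                                   # reject, stay in place


rng = np.random.default_rng(42)
pseudo_prior = StandardNormalPseudoPrior(rng)
x, k = np.zeros(1), 0
visits = np.zeros(2)
for _ in range(20000):
    x, k = between_state_move(x, k, 2, log_posterior, pseudo_prior, rng)
    visits[k] += 1
fractions = visits / visits.sum()  # should approach the toy weights
```

Because the toy pseudo-prior matches the within-state posterior shape exactly, the acceptance ratio reduces to the ratio of state weights, which is why the visit fractions track `LOG_WEIGHTS`.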

Examples

Basic usage:

>>> results = run_state_jump_sampler(
...     n_walkers=3,
...     n_steps=1000,
...     n_states=3,
...     n_dims=[2, 3, 1],
...     start_positions=[[0.5, 0.5], [1.0, 0.0, -1.0], [2.0]],
...     start_states=[0, 1, 2],
...     log_posterior=my_log_posterior,
...     log_pseudo_prior=my_log_pseudo_prior,
...     log_proposal=my_log_proposal,
...     prob_state=0.2
... )

Using a user-provided walker pool:

>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as walker_pool:
...     results = run_state_jump_sampler(
...         n_walkers=32,
...         n_steps=1000,
...         n_states=3,
...         n_dims=[2, 3, 1],
...         start_positions=start_positions,
...         start_states=start_states,
...         log_posterior=my_log_posterior,
...         log_pseudo_prior=my_log_pseudo_prior,
...         log_proposal=my_log_proposal,
...         walker_pool=walker_pool
...     )

Using a forward pool for parallel forward solver calls:

>>> from concurrent.futures import ProcessPoolExecutor
>>> with ProcessPoolExecutor(max_workers=4) as forward_pool:
...     results = run_state_jump_sampler(
...         n_walkers=32,
...         n_steps=1000,
...         n_states=3,
...         n_dims=[2, 3, 1],
...         start_positions=start_positions,
...         start_states=start_states,
...         log_posterior=my_log_posterior,
...         log_pseudo_prior=my_log_pseudo_prior,
...         log_proposal=my_log_proposal,
...         forward_pool=forward_pool
...     )