diff --git a/pymc/backends/ndarray.py b/pymc/backends/ndarray.py index a08fc8f47e..326fe1e3f3 100644 --- a/pymc/backends/ndarray.py +++ b/pymc/backends/ndarray.py @@ -113,6 +113,10 @@ def record(self, point, sampler_stats=None) -> None: if sampler_stats is not None: for data, vars in zip(self._stats, sampler_stats): for key, val in vars.items(): + # step_meta is a key used by the progress bars to track which draw came from which step instance. It + # should never be stored as a sampler statistic. + if key == "step_meta": + continue data[key][draw_idx] = val elif self._stats is not None: raise ValueError("Expected sampler_stats") diff --git a/pymc/sampling/mcmc.py b/pymc/sampling/mcmc.py index f2dfa6e9c2..f12c0345a0 100644 --- a/pymc/sampling/mcmc.py +++ b/pymc/sampling/mcmc.py @@ -15,6 +15,7 @@ """Functions for MCMC sampling.""" import contextlib +import itertools import logging import pickle import sys @@ -111,6 +112,7 @@ def instantiate_steppers( step_kwargs: dict[str, dict] | None = None, initial_point: PointType | None = None, compile_kwargs: dict | None = None, + step_id_generator: Iterator[int] | None = None, ) -> Step | list[Step]: """Instantiate steppers assigned to the model variables. @@ -139,6 +141,9 @@ def instantiate_steppers( if step_kwargs is None: step_kwargs = {} + if step_id_generator is None: + step_id_generator = itertools.count() + used_keys = set() if selected_steps: if initial_point is None: @@ -154,6 +159,7 @@ def instantiate_steppers( model=model, initial_point=initial_point, compile_kwargs=compile_kwargs, + step_id_generator=step_id_generator, **kwargs, ) steps.append(step) @@ -853,6 +859,8 @@ def joined_blas_limiter(): initial_points = [ipfn(seed) for ipfn, seed in zip(ipfns, random_seed_list)] # Instantiate automatically selected steps + # Use a counter to generate a unique id for each stepper used in the model. + step_id_generator = itertools.count() step = instantiate_steppers( model, steps=provided_steps, @@ -860,9 +868,10 @@ def joined_blas_limiter(): step_kwargs=kwargs, initial_point=initial_points[0], compile_kwargs=compile_kwargs, + step_id_generator=step_id_generator, ) if isinstance(step, list): - step = CompoundStep(step) + step = CompoundStep(step, step_id_generator=step_id_generator) if var_names is not None: trace_vars = [v for v in model.unobserved_RVs if v.name in var_names] diff --git a/pymc/step_methods/arraystep.py b/pymc/step_methods/arraystep.py index 0c20e09a47..54785c2fe8 100644 --- a/pymc/step_methods/arraystep.py +++ b/pymc/step_methods/arraystep.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Iterator from typing import cast import numpy as np @@ -43,14 +43,25 @@ class ArrayStep(BlockedStep): :py:func:`pymc.util.get_random_generator` for more information. """ - def __init__(self, vars, fs, allvars=False, blocked=True, rng: RandomGenerator = None): + def __init__( + self, + vars, + fs, + allvars=False, + blocked=True, + rng: RandomGenerator = None, + step_id_generator: Iterator[int] | None = None, + ): self.vars = vars self.fs = fs self.allvars = allvars self.blocked = blocked self.rng = get_random_generator(rng) + self._step_id = next(step_id_generator) if step_id_generator else None - def step(self, point: PointType) -> tuple[PointType, StatsType]: + def step( + self, point: PointType, step_parent_id: int | None = None + ) -> tuple[PointType, StatsType]: partial_funcs_and_point: list[Callable | PointType] = [ DictToArrayBijection.mapf(x, start_point=point) for x in self.fs ] @@ -61,6 +72,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]: apoint = DictToArrayBijection.map(var_dict) apoint_new, stats = self.astep(apoint, *partial_funcs_and_point) + for sts in stats: + sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id} + if not isinstance(apoint_new, RaveledVars): # We assume that the mapping has stayed the same apoint_new = RaveledVars(apoint_new, apoint.point_map_info) @@ -84,7 +98,14 @@ class ArrayStepShared(BlockedStep): and unmapping overhead as well as moving fewer variables around. """ - def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None): + def __init__( + self, + vars, + shared, + blocked=True, + rng: RandomGenerator = None, + step_id_generator: Iterator[int] | None = None, + ): """ Create the ArrayStepShared object. @@ -103,8 +124,11 @@ def __init__(self, vars, shared, blocked=True, rng: RandomGenerator = None): self.shared = {get_var_name(var): shared for var, shared in shared.items()} self.blocked = blocked self.rng = get_random_generator(rng) + self._step_id = next(step_id_generator) if step_id_generator else None - def step(self, point: PointType) -> tuple[PointType, StatsType]: + def step( + self, point: PointType, step_parent_id: int | None = None + ) -> tuple[PointType, StatsType]: full_point = None if self.shared: for name, shared_var in self.shared.items(): @@ -115,6 +139,9 @@ def step(self, point: PointType) -> tuple[PointType, StatsType]: q = DictToArrayBijection.map(point) apoint, stats = self.astep(q) + for sts in stats: + sts["step_meta"] = {"step_id": self._step_id, "step_parent_id": step_parent_id} + if not isinstance(apoint, RaveledVars): # We assume that the mapping has stayed the same apoint = RaveledVars(apoint, q.point_map_info) diff --git a/pymc/step_methods/compound.py b/pymc/step_methods/compound.py index d07b070f0f..d522fa9961 100644 --- a/pymc/step_methods/compound.py +++ b/pymc/step_methods/compound.py @@ -21,7 +21,7 @@ import warnings from abc import ABC, abstractmethod -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence from dataclasses import field from enum import IntEnum, unique from typing import Any @@ -29,6 +29,7 @@ import numpy as np from pytensor.graph.basic import Variable +from rich.progress import ProgressColumn from pymc.blocking import PointType, StatDtype, StatsDict, StatShape, StatsType from pymc.model import modelcontext @@ -124,6 +125,8 @@ class BlockedStep(ABC, WithSamplingState): def __new__(cls, *args, **kwargs): blocked = kwargs.get("blocked") + step_id_generator = kwargs.pop("step_id_generator", None) + if blocked is None: # Try to look up default value from class blocked = getattr(cls, "default_blocked", True) @@ -167,31 +170,89 @@ def __new__(cls, *args, **kwargs): # call __init__ _kwargs = kwargs.copy() _kwargs["rng"] = rng + _kwargs["step_id_generator"] = step_id_generator step.__init__([var], *args, **_kwargs) # Hack for creating the class correctly when unpickling. step.__newargs = ([var], *args), _kwargs steps.append(step) - return CompoundStep(steps) + return CompoundStep(steps, step_id_generator=step_id_generator) else: step = super().__new__(cls) step.stats_dtypes = stats_dtypes step.stats_dtypes_shapes = stats_dtypes_shapes + step._step_id = next(step_id_generator) if step_id_generator else None + # Hack for creating the class correctly when unpickling. step.__newargs = (vars, *args), kwargs return step - @staticmethod - def _progressbar_config(n_chains=1): + def _progressbar_config(self, n_chains: int = 1): + """ + Get progressbar configuration for this step sampler. + + By default, the progress bar displays no stats columns, only basic info (number of draws and sampling time). + Specific step methods should overload this method to specify which stats to display and how. + + Parameters + ---------- + n_chains: int + Number of chains being sampled. This controls the number of progress bars that will be displayed. + + Returns + ------- + columns: list of rich.progress.ProgressColumn + List of columns to display in the progress bar. + + stats: dict + Dictionary of statistics associated with each column. + """ columns = [] stats = {} return columns, stats - @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - return stats + def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + By default, the update is a no-op. Specific step methods should implement special logic for which + statistics to display and how. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats_dict: dict + Dictionary of statistics generated by the step sampler when taking the current step. The keys are the + names of the statistics and the values are the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ + return displayed_stats return update_stats @@ -200,7 +261,9 @@ def __getnewargs_ex__(self): return self.__newargs @abstractmethod - def step(self, point: PointType) -> tuple[PointType, StatsType]: + def step( + self, point: PointType, step_parent_id: int | None = None + ) -> tuple[PointType, StatsType]: """Perform a single step of the sampler.""" @staticmethod @@ -259,7 +322,7 @@ class CompoundStep(WithSamplingState): _state_class = CompoundStepState - def __init__(self, methods): + def __init__(self, methods, step_id_generator: Iterator[int] | None = None): self.methods = list(methods) self.stats_dtypes = [] for method in self.methods: @@ -269,11 +332,12 @@ def __init__(self, methods): f"Compound[{', '.join(getattr(m, 'name', 'UNNAMED_STEP') for m in self.methods)}]" ) self.tune = True + self._step_id = next(step_id_generator) if step_id_generator else None - def step(self, point) -> tuple[PointType, StatsType]: + def step(self, point, step_parent_id: int | None = None) -> tuple[PointType, StatsType]: stats = [] for method in self.methods: - point, sts = method.step(point) + point, sts = method.step(point, step_parent_id=self._step_id) stats.extend(sts) # Model logp can only be the logp of the _last_ stats, # if there is one. Pop all others. @@ -311,7 +375,28 @@ def set_rng(self, rng: RandomGenerator): for method, _rng in zip(self.methods, _rngs): method.set_rng(_rng) - def _progressbar_config(self, n_chains=1): + def _progressbar_config( + self, n_chains: int = 1 + ) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]: + """ + Get progressbar configuration for this step sampler. + + The columns of the rich progress bar displayed during sampler are chosen by the step samplers themselves. In + the compound step case, we display the set union of all columns from the sub-step samplers. + + Parameters + ---------- + n_chains: int + Number of chains being sampled. This controls the number of progress bars that will be displayed. + + Returns + ------- + columns: list of rich.progress.ProgressColumn + List of columns to display in the progress bar. + + stats: dict + Dictionary of statistics associated with each column. + """ from functools import reduce column_lists, stat_dict_list = zip( @@ -332,14 +417,56 @@ def _progressbar_config(self, n_chains=1): return columns, stats - def _make_update_stats_function(self): - update_fns = [method._make_update_stats_function() for method in self.methods] + def _make_update_stats_function(self) -> Callable[[dict, dict[int, dict], int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. - def update_stats(stats, step_stats, chain_idx): - for step_stat, update_fn in zip(step_stats, update_fns): - stats = update_fn(stats, step_stat, chain_idx) + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + update_fns = { + method._step_id: method._make_update_stats_function() for method in self.methods + } - return stats + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats_dict: dict of dict + List of dictionaries containing statistics generated by **each** step sampler in the CompoundStep when + taking the current step. For each dictionary, the keys are names of statistics and the values are + the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ + # TODO: The compound step is commonly made of many instances of the same step (e.g. 3 Metropolis steps). + # In this case, the current loop logic is just overriding each Metropolis steps' stats with those of the + # next step (so the user only ever sees the 3rd step's stats). We should have a better way to aggregate + # the stats from each step. + + for step_id, update_fn in update_fns.items(): + displayed_stats = update_fn(displayed_stats, step_stats_dict, chain_idx) + + return displayed_stats return update_stats diff --git a/pymc/step_methods/hmc/base_hmc.py b/pymc/step_methods/hmc/base_hmc.py index e8c96e8c4b..ccce50e9f5 100644 --- a/pymc/step_methods/hmc/base_hmc.py +++ b/pymc/step_methods/hmc/base_hmc.py @@ -18,6 +18,7 @@ import time from abc import abstractmethod +from collections.abc import Iterator from typing import Any, NamedTuple import numpy as np @@ -99,6 +100,7 @@ def __init__( step_rand=None, rng=None, initial_point: PointType | None = None, + step_id_generator: Iterator[int] | None = None, **pytensor_kwargs, ): """Set up Hamiltonian samplers with common structures. @@ -133,6 +135,7 @@ def __init__( **pytensor_kwargs: passed to PyTensor functions """ self._model = modelcontext(model) + self._step_id = next(step_id_generator) if step_id_generator else None if vars is None: vars = self._model.continuous_value_vars diff --git a/pymc/step_methods/hmc/nuts.py b/pymc/step_methods/hmc/nuts.py index 18707c3592..98ca981b8e 100644 --- a/pymc/step_methods/hmc/nuts.py +++ b/pymc/step_methods/hmc/nuts.py @@ -15,12 +15,13 @@ from __future__ import annotations from collections import namedtuple +from collections.abc import Callable from dataclasses import field import numpy as np from pytensor import config -from rich.progress import TextColumn +from rich.progress import ProgressColumn, TextColumn from rich.table import Column from pymc.stats.convergence import SamplerWarning @@ -231,8 +232,25 @@ def competence(var, has_grad): return Competence.PREFERRED return Competence.INCOMPATIBLE - @staticmethod - def _progressbar_config(n_chains=1): + def _progressbar_config( + self, n_chains: int = 1 + ) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]: + """ + Get progressbar configuration for this step sampler. + + Parameters + ---------- + n_chains: int + Number of chains being sampled. This controls the number of progress bars that will be displayed. + + Returns + ------- + columns: list of rich.progress.ProgressColumn + List of columns to display in the progress bar. + + stats: dict + Dictionary of statistics associated with each column. + """ columns = [ TextColumn("{task.fields[divergences]}", table_column=Column("Divergences", ratio=1)), TextColumn("{task.fields[step_size]:0.2f}", table_column=Column("Step size", ratio=1)), @@ -247,18 +265,51 @@ def _progressbar_config(n_chains=1): return columns, stats - @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats_dict: dict + Dictionary of statistics generated by the step sampler when taking the current step. The keys are the + names of the statistics and the values are the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ + step_stats = step_stats_dict[self._step_id] if not step_stats["tune"]: - stats["divergences"][chain_idx] += step_stats["diverging"] + displayed_stats["divergences"][chain_idx] += step_stats["diverging"] - stats["step_size"][chain_idx] = step_stats["step_size"] - stats["tree_size"][chain_idx] = step_stats["tree_size"] - return stats + displayed_stats["step_size"][chain_idx] = step_stats["step_size"] + displayed_stats["tree_size"][chain_idx] = step_stats["tree_size"] + return displayed_stats return update_stats diff --git a/pymc/step_methods/metropolis.py b/pymc/step_methods/metropolis.py index 70c650653d..e27ea83458 100644 --- a/pymc/step_methods/metropolis.py +++ b/pymc/step_methods/metropolis.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Callable +from collections.abc import Callable, Iterator from dataclasses import field from typing import Any @@ -24,7 +24,7 @@ from pytensor import tensor as pt from pytensor.graph.fg import MissingInputError from pytensor.tensor.random.basic import BernoulliRV, CategoricalRV -from rich.progress import TextColumn +from rich.progress import ProgressColumn, TextColumn from rich.table import Column import pymc as pm @@ -166,6 +166,7 @@ def __init__( initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = False, + step_id_generator: Iterator[int] | None = None, ): """Create an instance of a Metropolis stepper. @@ -258,7 +259,9 @@ def __init__( shared = pm.make_shared_replacements(initial_point, vars, model) self.delta_logp = delta_logp(initial_point, model.logp(), vars, shared, compile_kwargs) - super().__init__(vars, shared, blocked=blocked, rng=rng) + super().__init__( + vars, shared, blocked=blocked, rng=rng, step_id_generator=step_id_generator + ) def reset_tuning(self): """Reset the tuned sampler parameters to their initial values.""" @@ -327,8 +330,25 @@ def astep(self, q0: RaveledVars) -> tuple[RaveledVars, StatsType]: def competence(var, has_grad): return Competence.COMPATIBLE - @staticmethod - def _progressbar_config(n_chains=1): + def _progressbar_config( + self, n_chains: int = 1 + ) -> tuple[list[ProgressColumn], dict[str, np.ndarray | float]]: + """ + Get progressbar configuration for this step sampler. + + Parameters + ---------- + n_chains: int + Number of chains being sampled. This controls the number of progress bars that will be displayed. + + Returns + ------- + columns: list of rich.progress.ProgressColumn + List of columns to display in the progress bar. + + stats: dict + Dictionary of statistics associated with each column. + """ columns = [ TextColumn("{task.fields[tune]}", table_column=Column("Tuning", ratio=1)), TextColumn("{task.fields[scaling]:0.2f}", table_column=Column("Scaling", ratio=1)), @@ -345,17 +365,50 @@ def _progressbar_config(n_chains=1): return columns, stats - @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] + def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats_dict: dict + Dictionary of statistics generated by the step sampler when taking the current step. The keys are the + names of the statistics and the values are the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ + step_stats = step_stats_dict[self._step_id] - stats["tune"][chain_idx] = step_stats["tune"] - stats["accept_rate"][chain_idx] = step_stats["accept"] - stats["scaling"][chain_idx] = step_stats["scaling"] + displayed_stats["tune"][chain_idx] = step_stats["tune"] + displayed_stats["accept_rate"][chain_idx] = step_stats["accept"] + displayed_stats["scaling"][chain_idx] = step_stats["scaling"] - return stats + return displayed_stats return update_stats @@ -951,7 +1004,9 @@ def __init__( initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = True, + step_id_generator: Iterator[int] | None = None, ): + self._step_id = next(step_id_generator) if step_id_generator else None model = pm.modelcontext(model) if initial_point is None: initial_point = model.initial_point() diff --git a/pymc/step_methods/slicer.py b/pymc/step_methods/slicer.py index 9c10acfdf4..8350e6b767 100644 --- a/pymc/step_methods/slicer.py +++ b/pymc/step_methods/slicer.py @@ -11,10 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable, Iterator # Modified from original implementation by Dominik Wabersich (2013) - - import numpy as np from rich.progress import TextColumn @@ -35,9 +34,6 @@ LOOP_ERR_MSG = "max slicer iters %d exceeded" -dataclass_state - - @dataclass_state class SliceState(StepMethodState): w: np.ndarray @@ -91,7 +87,9 @@ def __init__( initial_point: PointType | None = None, compile_kwargs: dict | None = None, blocked: bool = False, # Could be true since tuning is independent across dims? + step_id_generator: Iterator[int] | None = None, ): + self._step_id = next(step_id_generator) if step_id_generator else None model = modelcontext(model) self.w = np.asarray(w).copy() self.tune = tune @@ -211,16 +209,49 @@ def _progressbar_config(n_chains=1): return columns, stats - @staticmethod - def _make_update_stats_function(): - def update_stats(stats, step_stats, chain_idx): - if isinstance(step_stats, list): - step_stats = step_stats[0] - - stats["tune"][chain_idx] = step_stats["tune"] - stats["nstep_out"][chain_idx] = step_stats["nstep_out"] - stats["nstep_in"][chain_idx] = step_stats["nstep_in"] - - return stats + def _make_update_stats_function(self) -> Callable[[dict, dict, int], dict]: + """ + Create an update function used by the progress bar to update statistics during sampling. + + Returns + ------- + update_stats: Callable + Function that updates displayed statistics for the current chain, given statistics generated by the step + during the most recent step. + """ + + def update_stats( + displayed_stats: dict[str, np.ndarray], + step_stats_dict: dict[int, dict[str, str | float | int | bool | None]], + chain_idx: int, + ) -> dict[str, np.ndarray]: + """ + Update the statistics displayed in the progress bar after each step. + + Parameters + ---------- + displayed_stats: dict + Dictionary of statistics displayed in the progress bar. The keys are the names of the statistics and + the values are the current values of the statistics, with one value per chain being sampled. + + step_stats_dict: dict + Dictionary of statistics generated by the step sampler when taking the current step. The keys are the + names of the statistics and the values are the values of the statistics generated by the step sampler. + + chain_idx: int + The chain number associated with the current step + + Returns + ------- + dict + The updated statistics dictionary to be displayed in the progress bar. + """ + step_stats = step_stats_dict[self._step_id] + + displayed_stats["tune"][chain_idx] = step_stats["tune"] + displayed_stats["nstep_out"][chain_idx] = step_stats["nstep_out"] + displayed_stats["nstep_in"][chain_idx] = step_stats["nstep_in"] + + return displayed_stats return update_stats diff --git a/pymc/util.py b/pymc/util.py index 979b3beebf..586d567204 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -883,7 +883,10 @@ def update(self, chain_idx, is_last, draw, tuning, stats): if not tuning and stats and stats[0].get("diverging"): self.divergences += 1 - self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx) + step_meta = [entry["step_meta"] for entry in stats] + step_id_to_stats = {meta["step_id"]: entry for meta, entry in zip(step_meta, stats)} + + self.progress_stats = self.update_stats(self.progress_stats, step_id_to_stats, chain_idx) more_updates = ( {stat: value[chain_idx] for stat, value in self.progress_stats.items()} if self.full_stats