
BUG: Progress bar throws error when nested CompoundSteps are present. #7721



Open
fruzti opened this issue Mar 11, 2025 · 1 comment · May be fixed by #7730



fruzti commented Mar 11, 2025

Describe the issue:

The progress bar throws an error when the sampler's step hierarchy contains a nested CompoundStep. With the progress bar disabled (progressbar=False), the error no longer occurs.

Reproducible code example:

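import pymc as pm
import pytensor.tensor as pt

with pm.Model() as model_with_errors:
    a = pm.Poisson("a", mu=10)
    b = pm.Binomial("b", n=a, p=0.8)
    c = pm.Poisson("c", mu=11)
    d = pm.Dirichlet("d", a=pt.stack([c, b]))

    # Raises TypeError while the progress bar is enabled (the default).
    pm.sample(draws=1000, tune=1000, chains=4)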

Error message:

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>CompoundStep
>>Metropolis: [a]
>>Metropolis: [b]
>>Metropolis: [c]
>NUTS: [d]

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[250], line 11
      7 c   = pm.Poisson("c",mu=11)
      9 d   = pm.Dirichlet("d",a=pt.stack([c,b]))
---> 11 pm.sample(draws=1000,tune=1000,chains=4)

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\mcmc.py:935, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    933 _print_step_hierarchy(step)
    934 try:
--> 935     _mp_sample(**sample_args, **parallel_args)
    936 except pickle.PickleError:
    937     _log.warning("Could not pickle model, sampling singlethreaded.")

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\mcmc.py:1411, in _mp_sample(draws, tune, step, chains, cores, rngs, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs)
   1409 try:
   1410     with sampler:
-> 1411         for draw in sampler:
   1412             strace = traces[draw.chain]
   1413             if not zarr_recording:
   1414                 # Zarr recording happens in each process

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\parallel.py:513, in ParallelSampler.__iter__(self)
    510 draw = ProcessAdapter.recv_draw(self._active)
    511 proc, is_last, draw, tuning, stats = draw
--> 513 self._progress.update(
    514     chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats
    515 )
    517 if is_last:
    518     proc.join()

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\util.py:886, in ProgressBarManager.update(self, chain_idx, is_last, draw, tuning, stats)
    883 if not tuning and stats and stats[0].get("diverging"):
    884     self.divergences += 1
--> 886 self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
    887 more_updates = (
    888     {stat: value[chain_idx] for stat, value in self.progress_stats.items()}
    889     if self.full_stats
    890     else {}
    891 )
    893 self._progress.update(
    894     self.tasks[chain_idx],
    895     completed=draw,
   (...)
    899     **more_updates,
    900 )

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    338 def update_stats(stats, step_stats, chain_idx):
    339     for step_stat, update_fn in zip(step_stats, update_fns):
--> 340         stats = update_fn(stats, step_stat, chain_idx)
    342     return stats

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    338 def update_stats(stats, step_stats, chain_idx):
    339     for step_stat, update_fn in zip(step_stats, update_fns):
--> 340         stats = update_fn(stats, step_stat, chain_idx)
    342     return stats

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\metropolis.py:354, in Metropolis._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    351 if isinstance(step_stats, list):
    352     step_stats = step_stats[0]
--> 354 stats["tune"][chain_idx] = step_stats["tune"]
    355 stats["accept_rate"][chain_idx] = step_stats["accept"]
    356 stats["scaling"][chain_idx] = step_stats["scaling"]

TypeError: string indices must be integers, not 'str'
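The traceback is consistent with a mismatch between how step stats are produced and how the progress bar consumes them. The pure-Python sketch below reproduces that pattern under an assumption inferred from the traceback, not confirmed from the PyMC source: each leaf step emits one stats dict, the outer CompoundStep receives a flat list of four dicts (a, b, c, d), but its update_stats zips that list against its two direct children only (the inner CompoundStep and NUTS). The inner compound therefore receives a single dict, iterates over its keys, and hands the string "tune" to the Metropolis update. All function names here are hypothetical stand-ins, not PyMC code.

```python
# Hypothetical stand-ins for the per-step progress-bar update functions.
# Assumption: every leaf step emits one stats dict, the outer compound gets a
# flat list of those dicts (a, b, c, d), and each compound pairs that list
# with its *direct* children only.


def metropolis_update(stats, step_stats, chain_idx):
    # Mirrors the failing line in the traceback: expects a dict (or a list
    # whose first item is a dict) and indexes it by "tune".
    if isinstance(step_stats, list):
        step_stats = step_stats[0]
    stats["tune"][chain_idx] = step_stats["tune"]
    return stats


def nuts_update(stats, step_stats, chain_idx):
    if isinstance(step_stats, list):
        step_stats = step_stats[0]
    stats["diverging"][chain_idx] = step_stats.get("diverging", False)
    return stats


def make_compound_update(update_fns):
    # Mirrors the compound update loop: zip() pairs stats with direct children.
    def update(stats, step_stats, chain_idx):
        for step_stat, update_fn in zip(step_stats, update_fns):
            stats = update_fn(stats, step_stat, chain_idx)
        return stats

    return update


inner_update = make_compound_update([metropolis_update] * 3)  # a, b, c
outer_update = make_compound_update([inner_update, nuts_update])  # inner, d

# One dict per leaf step, in hierarchy order: a, b, c, d.
flat_step_stats = [{"tune": True}, {"tune": True}, {"tune": True}, {"diverging": False}]
progress = {"tune": [None], "diverging": [None]}

error_message = None
try:
    # The inner compound receives a's dict; iterating a dict yields its keys,
    # so metropolis_update receives the string "tune" and indexes it by "tune".
    outer_update(progress, flat_step_stats, 0)
except TypeError as err:
    error_message = str(err)

print(error_message)
```

Running it prints a TypeError message beginning with "string indices must be integers", the same failure as in the traceback above.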

PyMC version information:

pymc 5.21.0

Context for the issue:

Since the error only occurs while the progress bar is active, it does not seem urgent: passing progressbar=False to pm.sample works around it. Hopefully others who hit this error can find the workaround here.
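One possible direction for a fix, sketched under the same assumption (a flat list with one stats dict per leaf step) and using hypothetical names; this is not necessarily what #7730 implements. Each compound update function knows how many leaf steps every child owns and hands that child exactly its slice of the flat list, so nested compounds consume the right number of entries.

```python
from typing import Callable, List, Tuple

# Hypothetical update-function signature: (stats, step_stats, chain_idx) -> stats,
# where step_stats is the slice of the flat stats list owned by that step.
UpdateFn = Callable[[dict, List[dict], int], dict]


def metropolis_update(stats: dict, step_stats: List[dict], chain_idx: int) -> dict:
    # A leaf step owns exactly one stats dict.
    stats["tune"][chain_idx] = step_stats[0]["tune"]
    return stats


def nuts_update(stats: dict, step_stats: List[dict], chain_idx: int) -> dict:
    stats["diverging"][chain_idx] = step_stats[0].get("diverging", False)
    return stats


def make_compound_update(children: List[Tuple[UpdateFn, int]]) -> Tuple[UpdateFn, int]:
    """Build a compound updater from (update_fn, n_leaf_steps) pairs.

    Returns the updater together with its own leaf count, so compounds nest.
    """
    n_total = sum(n_leaves for _, n_leaves in children)

    def update(stats: dict, step_stats: List[dict], chain_idx: int) -> dict:
        if len(step_stats) != n_total:
            raise ValueError(f"expected {n_total} stats dicts, got {len(step_stats)}")
        start = 0
        for update_fn, n_leaves in children:
            # Hand each child exactly its share of the flat list.
            stats = update_fn(stats, step_stats[start:start + n_leaves], chain_idx)
            start += n_leaves
        return stats

    return update, n_total


inner = make_compound_update([(metropolis_update, 1)] * 3)  # a, b, c
outer_update, _ = make_compound_update([inner, (nuts_update, 1)])  # inner, d

flat_step_stats = [{"tune": False}, {"tune": False}, {"tune": False}, {"diverging": True}]
progress = {"tune": [None, None], "diverging": [None, None]}
progress = outer_update(progress, flat_step_stats, 1)
print(progress)
```

In this sketch the three Metropolis steps write to the same progress slot, so the last one wins. Slicing by leaf count leaves the flat stats format untouched; the alternative would be to have nested compounds emit nested lists, which would change what every consumer of step stats receives.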

fruzti added the bug label on Mar 11, 2025

welcome bot commented Mar 11, 2025

🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.
