Skip to content

ADVI errors in numba mode for StudentT likelihood when total_size is set #7778

@jessegrabowski

Description

import pymc as pm
import numpy as np

# Minimal reproduction: fitting this minibatched StudentT model with
# compile_kwargs={'mode': 'NUMBA'} raises NotImplementedError while the
# step function is being compiled. According to the report, the same model
# fits without error when total_size is None.
# The generator is unseeded, so the simulated data differ on every run.
rng = np.random.default_rng()

with pm.Model() as m:
    # Full dataset: a 1000 x 5 design matrix and a 1000-element target.
    data = pm.Data('data', rng.normal(size=(1000, 5)))
    obs = pm.Data('obs', rng.normal(size=(1000,)))
    
    # Mini-batches of 128 rows taken from `data` and `obs`.
    data_batch, obs_batch = pm.Minibatch(data, obs, batch_size=128)
    
    beta = pm.Normal('beta', size=(5,))
    # Linear predictor for the current mini-batch.
    mu = data_batch @ beta
    sigma = pm.Exponential('sigma', 1)
    
    # total_size=1000 (the number of rows in the full dataset) is the
    # argument that triggers the failure; per the report, omitting it
    # (total_size=None) avoids the error.
    y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3, observed=obs_batch, total_size=1000)
    
    # Fails before any iterations run: compiling the step function with the
    # NUMBA linker reaches a t_rv node and raises
    # "NotImplementedError: Core implementation of t_rv{...} not implemented."
    idata = pm.fit(n=1_000_000, 
                   compile_kwargs={'mode':'NUMBA'})
Full Traceback
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[1], line 18
     14 sigma = pm.Exponential('sigma', 1)
     16 y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3, observed=obs_batch, total_size=1000)
---> 18 idata = pm.fit(n=1_000_000, 
     19                compile_kwargs={'mode':'NUMBA'})

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:775, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    773 else:
    774     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 775 return inference.fit(n, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:158, in Inference.fit(self, n, score, callbacks, progressbar, progressbar_theme, **kwargs)
    156     callbacks = []
    157 score = self._maybe_score(score)
--> 158 step_func = self.objective.step_function(score=score, **kwargs)
    160 if score:
    161     state = self._iterate_with_loss(
    162         0, n, step_func, progressbar, progressbar_theme, callbacks
    163     )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/opvi.py:405, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, compile_kwargs, fn_kwargs)
    403 seed = self.approx.rng.randint(2**30, dtype=np.int64)
    404 if score:
--> 405     step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs)
    406 else:
    407     step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/pytensorf.py:947, in compile(inputs, outputs, random_seed, mode, **kwargs)
    945 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
    946 mode = Mode(linker=mode.linker, optimizer=opt_qry)
--> 947 pytensor_function = pytensor.function(
    948     inputs,
    949     outputs,
    950     updates={**rng_updates, **kwargs.pop("updates", {})},
    951     mode=mode,
    952     **kwargs,
    953 )
    954 return pytensor_function

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
    321     fn = orig_function(
    322         inputs,
    323         outputs,
   (...)    327         trust_input=trust_input,
    328     )
    329 else:
    330     # note: pfunc will also call orig_function -- orig_function is
    331     #      a choke point that all compilation must pass through
--> 332     fn = pfunc(
    333         params=inputs,
    334         outputs=outputs,
    335         mode=mode,
    336         updates=updates,
    337         givens=givens,
    338         no_default_updates=no_default_updates,
    339         accept_inplace=accept_inplace,
    340         name=name,
    341         rebuild_strict=rebuild_strict,
    342         allow_input_downcast=allow_input_downcast,
    343         on_unused_input=on_unused_input,
    344         profile=profile,
    345         output_keys=output_keys,
    346         trust_input=trust_input,
    347     )
    348 return fn

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
    452     profile = ProfileStats(message=profile)
    454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
    455     params,
    456     outputs,
   (...)    463     fgraph=fgraph,
    464 )
--> 466 return orig_function(
    467     inputs,
    468     cloned_outputs,
    469     mode,
    470     accept_inplace=accept_inplace,
    471     name=name,
    472     profile=profile,
    473     on_unused_input=on_unused_input,
    474     output_keys=output_keys,
    475     fgraph=fgraph,
    476     trust_input=trust_input,
    477 )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1833, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
   1820     m = Maker(
   1821         inputs,
   1822         outputs,
   (...)   1830         trust_input=trust_input,
   1831     )
   1832     with config.change_flags(compute_test_value="off"):
-> 1833         fn = m.create(defaults)
   1834 finally:
   1835     if profile and fn:

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1717, in FunctionMaker.create(self, input_storage, storage_map)
   1714 start_import_time = pytensor.link.c.cmodule.import_time
   1716 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1717     _fn, _i, _o = self.linker.make_thunk(
   1718         input_storage=input_storage_lists, storage_map=storage_map
   1719     )
   1721 end_linker = time.perf_counter()
   1723 linker_time = end_linker - start_linker

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
    238 def make_thunk(
    239     self,
    240     input_storage: Optional["InputStorageType"] = None,
   (...)    243     **kwargs,
    244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245     return self.make_all(
    246         input_storage=input_storage,
    247         output_storage=output_storage,
    248         storage_map=storage_map,
    249     )[:3]

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
    692 for k in storage_map:
    693     compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
    696     compute_map, nodes, input_storage, output_storage, storage_map
    697 )
    699 [fn] = thunks
    700 fn.jit_fn = jit_fn

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
    644 # This is a bit hackish, but we only return one of the output nodes
    645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
    648     self.fgraph,
    649     order=order,
    650     input_storage=input_storage,
    651     output_storage=output_storage,
    652     storage_map=storage_map,
    653 )
    655 thunk_inputs = self.create_thunk_inputs(storage_map)
    656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/linker.py:10, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
      7 def fgraph_convert(self, fgraph, **kwargs):
      8     from pytensor.link.numba.dispatch import numba_funcify
---> 10     return numba_funcify(fgraph, **kwargs)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:380, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    373 @numba_funcify.register(FunctionGraph)
    374 def numba_funcify_FunctionGraph(
    375     fgraph,
   (...)    378     **kwargs,
    379 ):
--> 380     return fgraph_to_python(
    381         fgraph,
    382         numba_funcify,
    383         type_conversion_fn=numba_typify,
    384         fgraph_name=fgraph_name,
    385         **kwargs,
    386     )

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
    734 body_assigns = []
    735 for node in order:
--> 736     compiled_func = op_conversion_fn(
    737         node.op, node=node, storage_map=storage_map, **kwargs
    738     )
    740     # Create a local alias with a unique name
    741     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:401, in numba_funcify_RandomVariable(op, node, **kwargs)
    398 core_shape_len = get_vector_length(core_shape)
    399 inplace = rv_op.inplace
--> 401 core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
    402 nin = 1 + len(dist_params)  # rng + params
    403 core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)

File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
    908 if not args:
    909     raise TypeError(f'{funcname} requires at least '
    910                     '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:47, in numba_core_rv_funcify(op, node)
     44 @singledispatch
     45 def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
     46     """Return the core function for a random variable operation."""
---> 47     raise NotImplementedError(f"Core implementation of {op} not implemented.")

NotImplementedError: Core implementation of t_rv{"(),(),()->()"} not implemented.

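Interestingly, it works fine if you set `total_size=None`. The traceback shows that, with `total_size` set, the compiled step function still contains a `t_rv` random-variable node. Numba dispatch (`numba_funcify_RandomVariable`) has no core implementation for that node, so compilation fails.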

Metadata

Assignees

No one assigned

    Labels

    VI (Variational Inference), bug, numba

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions

      pFad - Phonifier reborn

      Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

      Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


      Alternative Proxies:

      Alternative Proxy

      pFad Proxy

      pFad v3 Proxy

      pFad v4 Proxy