-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Closed
pymc-devs/pytensor
#1402
Labels
Description
import pymc as pm
import numpy as np

# Minimal reproducer: fitting a minibatched StudentT regression with
# `pm.fit` (ADVI by default) and compiling with the NUMBA backend raises
#   NotImplementedError: Core implementation of t_rv{"(),(),()->()"} not implemented.
# Per the report, the same model fits fine when `total_size=None`.

rng = np.random.default_rng()

with pm.Model() as m:
    # Full dataset: 1000 rows with 5 features, plus 1000 observations.
    data = pm.Data('data', rng.normal(size=(1000, 5)))
    obs = pm.Data('obs', rng.normal(size=(1000,)))

    # Minibatches of 128 rows drawn jointly from `data` and `obs`.
    data_batch, obs_batch = pm.Minibatch(data, obs, batch_size=128)

    beta = pm.Normal('beta', size=(5,))
    mu = data_batch @ beta
    sigma = pm.Exponential('sigma', 1)

    # `total_size=1000` tells PyMC the size of the full dataset behind the
    # minibatch; this argument is what triggers the NUMBA failure.
    y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3, observed=obs_batch, total_size=1000)

    # Must run inside the model context: pm.fit is given no explicit model.
    # The error is raised while the step function is compiled, before any
    # iteration runs, so the value of `n` does not matter for the reproducer.
    idata = pm.fit(n=1_000_000,
                   compile_kwargs={'mode': 'NUMBA'})
Full Traceback
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[1], line 18
14 sigma = pm.Exponential('sigma', 1)
16 y_hat = pm.StudentT('y_hat', mu=mu, sigma=sigma, nu=3, observed=obs_batch, total_size=1000)
---> 18 idata = pm.fit(n=1_000_000,
19 compile_kwargs={'mode':'NUMBA'})
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:775, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
773 else:
774 raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 775 return inference.fit(n, **kwargs)
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/inference.py:158, in Inference.fit(self, n, score, callbacks, progressbar, progressbar_theme, **kwargs)
156 callbacks = []
157 score = self._maybe_score(score)
--> 158 step_func = self.objective.step_function(score=score, **kwargs)
160 if score:
161 state = self._iterate_with_loss(
162 0, n, step_func, progressbar, progressbar_theme, callbacks
163 )
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
41 @wraps(f)
42 def res(*args, **kwargs):
43 with self:
---> 44 return f(*args, **kwargs)
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/variational/opvi.py:405, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, compile_kwargs, fn_kwargs)
403 seed = self.approx.rng.randint(2**30, dtype=np.int64)
404 if score:
--> 405 step_fn = compile([], updates.loss, updates=updates, random_seed=seed, **compile_kwargs)
406 else:
407 step_fn = compile([], [], updates=updates, random_seed=seed, **compile_kwargs)
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pymc/pytensorf.py:947, in compile(inputs, outputs, random_seed, mode, **kwargs)
945 opt_qry = mode.provided_optimizer.including("random_make_inplace", check_parameter_opt)
946 mode = Mode(linker=mode.linker, optimizer=opt_qry)
--> 947 pytensor_function = pytensor.function(
948 inputs,
949 outputs,
950 updates={**rng_updates, **kwargs.pop("updates", {})},
951 mode=mode,
952 **kwargs,
953 )
954 return pytensor_function
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/__init__.py:332, in function(inputs, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, trust_input)
321 fn = orig_function(
322 inputs,
323 outputs,
(...) 327 trust_input=trust_input,
328 )
329 else:
330 # note: pfunc will also call orig_function -- orig_function is
331 # a choke point that all compilation must pass through
--> 332 fn = pfunc(
333 params=inputs,
334 outputs=outputs,
335 mode=mode,
336 updates=updates,
337 givens=givens,
338 no_default_updates=no_default_updates,
339 accept_inplace=accept_inplace,
340 name=name,
341 rebuild_strict=rebuild_strict,
342 allow_input_downcast=allow_input_downcast,
343 on_unused_input=on_unused_input,
344 profile=profile,
345 output_keys=output_keys,
346 trust_input=trust_input,
347 )
348 return fn
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/pfunc.py:466, in pfunc(params, outputs, mode, updates, givens, no_default_updates, accept_inplace, name, rebuild_strict, allow_input_downcast, profile, on_unused_input, output_keys, fgraph, trust_input)
452 profile = ProfileStats(message=profile)
454 inputs, cloned_outputs = construct_pfunc_ins_and_outs(
455 params,
456 outputs,
(...) 463 fgraph=fgraph,
464 )
--> 466 return orig_function(
467 inputs,
468 cloned_outputs,
469 mode,
470 accept_inplace=accept_inplace,
471 name=name,
472 profile=profile,
473 on_unused_input=on_unused_input,
474 output_keys=output_keys,
475 fgraph=fgraph,
476 trust_input=trust_input,
477 )
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1833, in orig_function(inputs, outputs, mode, accept_inplace, name, profile, on_unused_input, output_keys, fgraph, trust_input)
1820 m = Maker(
1821 inputs,
1822 outputs,
(...) 1830 trust_input=trust_input,
1831 )
1832 with config.change_flags(compute_test_value="off"):
-> 1833 fn = m.create(defaults)
1834 finally:
1835 if profile and fn:
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/compile/function/types.py:1717, in FunctionMaker.create(self, input_storage, storage_map)
1714 start_import_time = pytensor.link.c.cmodule.import_time
1716 with config.change_flags(traceback__limit=config.traceback__compile_limit):
-> 1717 _fn, _i, _o = self.linker.make_thunk(
1718 input_storage=input_storage_lists, storage_map=storage_map
1719 )
1721 end_linker = time.perf_counter()
1723 linker_time = end_linker - start_linker
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:245, in LocalLinker.make_thunk(self, input_storage, output_storage, storage_map, **kwargs)
238 def make_thunk(
239 self,
240 input_storage: Optional["InputStorageType"] = None,
(...) 243 **kwargs,
244 ) -> tuple["BasicThunkType", "InputStorageType", "OutputStorageType"]:
--> 245 return self.make_all(
246 input_storage=input_storage,
247 output_storage=output_storage,
248 storage_map=storage_map,
249 )[:3]
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:695, in JITLinker.make_all(self, input_storage, output_storage, storage_map)
692 for k in storage_map:
693 compute_map[k] = [k.owner is None]
--> 695 thunks, nodes, jit_fn = self.create_jitable_thunk(
696 compute_map, nodes, input_storage, output_storage, storage_map
697 )
699 [fn] = thunks
700 fn.jit_fn = jit_fn
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/basic.py:647, in JITLinker.create_jitable_thunk(self, compute_map, order, input_storage, output_storage, storage_map)
644 # This is a bit hackish, but we only return one of the output nodes
645 output_nodes = [o.owner for o in self.fgraph.outputs if o.owner is not None][:1]
--> 647 converted_fgraph = self.fgraph_convert(
648 self.fgraph,
649 order=order,
650 input_storage=input_storage,
651 output_storage=output_storage,
652 storage_map=storage_map,
653 )
655 thunk_inputs = self.create_thunk_inputs(storage_map)
656 thunk_outputs = [storage_map[n] for n in self.fgraph.outputs]
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/linker.py:10, in NumbaLinker.fgraph_convert(self, fgraph, **kwargs)
7 def fgraph_convert(self, fgraph, **kwargs):
8 from pytensor.link.numba.dispatch import numba_funcify
---> 10 return numba_funcify(fgraph, **kwargs)
File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/basic.py:380, in numba_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
373 @numba_funcify.register(FunctionGraph)
374 def numba_funcify_FunctionGraph(
375 fgraph,
(...) 378 **kwargs,
379 ):
--> 380 return fgraph_to_python(
381 fgraph,
382 numba_funcify,
383 type_conversion_fn=numba_typify,
384 fgraph_name=fgraph_name,
385 **kwargs,
386 )
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/utils.py:736, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, unique_name, **kwargs)
734 body_assigns = []
735 for node in order:
--> 736 compiled_func = op_conversion_fn(
737 node.op, node=node, storage_map=storage_map, **kwargs
738 )
740 # Create a local alias with a unique name
741 local_compiled_func_name = unique_name(compiled_func)
File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:401, in numba_funcify_RandomVariable(op, node, **kwargs)
398 core_shape_len = get_vector_length(core_shape)
399 inplace = rv_op.inplace
--> 401 core_rv_fn = numba_core_rv_funcify(rv_op, rv_node)
402 nin = 1 + len(dist_params) # rng + params
403 core_op_fn = store_core_outputs(core_rv_fn, nin=nin, nout=1)
File ~/mambaforge/envs/econ/lib/python3.12/functools.py:912, in singledispatch.<locals>.wrapper(*args, **kw)
908 if not args:
909 raise TypeError(f'{funcname} requires at least '
910 '1 positional argument')
--> 912 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/econ/lib/python3.12/site-packages/pytensor/link/numba/dispatch/random.py:47, in numba_core_rv_funcify(op, node)
44 @singledispatch
45 def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
46 """Return the core function for a random variable operation."""
---> 47 raise NotImplementedError(f"Core implementation of {op} not implemented.")
NotImplementedError: Core implementation of t_rv{"(),(),()->()"} not implemented.
Interestingly, it works fine if you set `total_size=None`.