From 4968cb3ed4b76a147439b20c6f58e36c6e2b4839 Mon Sep 17 00:00:00 2001 From: Richard Murray Date: Fri, 18 Nov 2022 22:00:22 -0800 Subject: [PATCH] check for and fix mutable keyword args --- control/flatsys/flatsys.py | 12 +++--- control/flatsys/linflat.py | 4 +- control/iosys.py | 5 ++- control/tests/kwargs_test.py | 71 ++++++++++++++++++++++++++++++++++-- 4 files changed, 78 insertions(+), 14 deletions(-) diff --git a/control/flatsys/flatsys.py b/control/flatsys/flatsys.py index 849c41c72..e0023c4de 100644 --- a/control/flatsys/flatsys.py +++ b/control/flatsys/flatsys.py @@ -142,7 +142,7 @@ def __init__(self, forward, reverse, # flat system updfcn=None, outfcn=None, # I/O system inputs=None, outputs=None, - states=None, params={}, dt=None, name=None): + states=None, params=None, dt=None, name=None): """Create a differentially flat I/O system. The FlatIOSystem constructor is used to create an input/output system @@ -171,7 +171,7 @@ def __str__(self): + f"Forward: {self.forward}\n" \ + f"Reverse: {self.reverse}" - def forward(self, x, u, params={}): + def forward(self, x, u, params=None): """Compute the flat flag given the states and input. @@ -200,7 +200,7 @@ def forward(self, x, u, params={}): """ raise NotImplementedError("internal error; forward method not defined") - def reverse(self, zflag, params={}): + def reverse(self, zflag, params=None): """Compute the states and input given the flat flag. Parameters @@ -224,18 +224,18 @@ def reverse(self, zflag, params={}): """ raise NotImplementedError("internal error; reverse method not defined") - def _flat_updfcn(self, t, x, u, params={}): + def _flat_updfcn(self, t, x, u, params=None): # TODO: implement state space update using flat coordinates raise NotImplementedError("update function for flat system not given") - def _flat_outfcn(self, t, x, u, params={}): + def _flat_outfcn(self, t, x, u, params=None): # Return the flat output zflag = self.forward(x, u, params) return np.array([zflag[i][0] for i in range(len(zflag))]) # Utility function to compute flag matrix given a basis -def _basis_flag_matrix(sys, basis, flag, t, params={}): +def _basis_flag_matrix(sys, basis, flag, t): """Compute the matrix of basis functions and their derivatives This function computes the matrix ``M`` that is used to solve for the diff --git a/control/flatsys/linflat.py b/control/flatsys/linflat.py index e4a31c6de..8e6c23604 100644 --- a/control/flatsys/linflat.py +++ b/control/flatsys/linflat.py @@ -142,11 +142,11 @@ def reverse(self, zflag, params): return np.reshape(x, self.nstates), np.reshape(u, self.ninputs) # Update function - def _rhs(self, t, x, u, params={}): + def _rhs(self, t, x, u): # Use LinearIOSystem._rhs instead of default (MRO) NonlinearIOSystem return LinearIOSystem._rhs(self, t, x, u) # output function - def _out(self, t, x, u, params={}): + def _out(self, t, x, u): # Use LinearIOSystem._out instead of default (MRO) NonlinearIOSystem return LinearIOSystem._out(self, t, x, u) diff --git a/control/iosys.py b/control/iosys.py index df75f3b54..6fa4a3e76 100644 --- a/control/iosys.py +++ b/control/iosys.py @@ -1584,7 +1584,7 @@ def __init__(self, io_sys, ss_sys=None): def input_output_response( sys, T, U=0., X0=0, params=None, transpose=False, return_x=False, squeeze=None, - solve_ivp_kwargs={}, t_eval='T', **kwargs): + solve_ivp_kwargs=None, t_eval='T', **kwargs): """Compute the output response of a system to a given input. Simulate a dynamical system with a given input and return its output @@ -1650,7 +1650,7 @@ def input_output_response( solve_ivp_method : str, optional Set the method used by :func:`scipy.integrate.solve_ivp`. Defaults to 'RK45'. - solve_ivp_kwargs : str, optional + solve_ivp_kwargs : dict, optional Pass additional keywords to :func:`scipy.integrate.solve_ivp`. Raises @@ -1676,6 +1676,7 @@ def input_output_response( # # Figure out the method to be used + solve_ivp_kwargs = solve_ivp_kwargs.copy() if solve_ivp_kwargs else {} if kwargs.get('solve_ivp_method', None): if kwargs.get('method', None): raise ValueError("ivp_method specified more than once") diff --git a/control/tests/kwargs_test.py b/control/tests/kwargs_test.py index 855bb9dda..2dc7f0563 100644 --- a/control/tests/kwargs_test.py +++ b/control/tests/kwargs_test.py @@ -38,6 +38,10 @@ def test_kwarg_search(module, prefix): # Skip anything that isn't part of the control package continue + # Look for classes and then check member functions + if inspect.isclass(obj): + test_kwarg_search(obj, prefix + obj.__name__ + '.') + # Only look for functions with keyword arguments if not inspect.isfunction(obj): continue @@ -70,10 +74,6 @@ def test_kwarg_search(module, prefix): f"'unrecognized keyword' not found in unit test " f"for {name}") - # Look for classes and then check member functions - if inspect.isclass(obj): - test_kwarg_search(obj, prefix + obj.__name__ + '.') - @pytest.mark.parametrize( "function, nsssys, ntfsys, moreargs, kwargs", @@ -201,3 +201,66 @@ def test_matplotlib_kwargs(function, nsysargs, moreargs, kwargs, mplcleanup): 'TimeResponseData.__call__': trdata_test.test_response_copy, 'TransferFunction.__init__': test_unrecognized_kwargs, } + +# +# Look for keywords with mutable defaults +# +# This test goes through every function and looks for signatures that have a +# default value for a keyword that is mutable. An error is generated unless +# the function is listed in the `mutable_ok` set (which should only be used +# for cases were the code has been explicitly checked to make sure that the +# value of the mutable is not modified in the code). +# +mutable_ok = { # initial and date + control.flatsys.SystemTrajectory.__init__, # RMM, 18 Nov 2022 + control.freqplot._add_arrows_to_line2D, # RMM, 18 Nov 2022 + control.namedio._process_dt_keyword, # RMM, 13 Nov 2022 + control.namedio._process_namedio_keywords, # RMM, 18 Nov 2022 + control.optimal.OptimalControlProblem.__init__, # RMM, 18 Nov 2022 + control.optimal.solve_ocp, # RMM, 18 Nov 2022 + control.optimal.create_mpc_iosystem, # RMM, 18 Nov 2022 +} + +@pytest.mark.parametrize("module", [control, control.flatsys]) +def test_mutable_defaults(module, recurse=True): + # Look through every object in the package + for name, obj in inspect.getmembers(module): + # Skip anything that is outside of this module + if inspect.getmodule(obj) is not None and \ + not inspect.getmodule(obj).__name__.startswith('control'): + # Skip anything that isn't part of the control package + continue + + # Look for classes and then check member functions + if inspect.isclass(obj): + test_mutable_defaults(obj, True) + + # Look for modules and check for internal functions (w/ no recursion) + if inspect.ismodule(obj) and recurse: + test_mutable_defaults(obj, False) + + # Only look at functions and skip any that are marked as OK + if not inspect.isfunction(obj) or obj in mutable_ok: + continue + + # Get the signature for the function + sig = inspect.signature(obj) + + # Skip anything that is inherited + if inspect.isclass(module) and obj.__name__ not in module.__dict__: + continue + + # See if there is a variable keyword argument + for argname, par in sig.parameters.items(): + if par.default is inspect._empty or \ + not par.kind == inspect.Parameter.KEYWORD_ONLY and \ + not par.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: + continue + + # Check to see if the default value is mutable + if par.default is not None and not \ + isinstance(par.default, (bool, int, float, tuple, str)): + pytest.fail( + f"function '{obj.__name__}' in module '{module.__name__}'" + f" has mutable default for keyword '{par.name}'") + pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy