Skip to content

Commit 6b61ed0

Browse files
committed
Add slicing access for state-space models with tests
1 parent ad6b49e commit 6b61ed0

File tree

3 files changed

+31
-10
lines changed

3 files changed

+31
-10
lines changed

control/statesp.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import math
5151
from copy import deepcopy
5252
from warnings import warn
53+
from collections.abc import Iterable
5354

5455
import numpy as np
5556
import scipy as sp
@@ -1215,17 +1216,16 @@ def append(self, other):
12151216

12161217
def __getitem__(self, indices):
12171218
"""Array style access"""
1218-
if len(indices) != 2:
1219+
if not isinstance(indices, Iterable) or len(indices) != 2:
12191220
raise IOError('must provide indices of length 2 for state space')
1220-
outdx = indices[0] if isinstance(indices[0], list) else [indices[0]]
1221-
inpdx = indices[1] if isinstance(indices[1], list) else [indices[1]]
1221+
outdx, inpdx = indices
1222+
if not isinstance(outdx, (int, slice)) or not isinstance(inpdx, (int, slice)):
1223+
raise TypeError(f"system indices must be integers or slices")
12221224
sysname = config.defaults['iosys.indexed_system_name_prefix'] + \
12231225
self.name + config.defaults['iosys.indexed_system_name_suffix']
12241226
return StateSpace(
12251227
self.A, self.B[:, inpdx], self.C[outdx, :], self.D[outdx, inpdx],
1226-
self.dt, name=sysname,
1227-
inputs=[self.input_labels[i] for i in list(inpdx)],
1228-
outputs=[self.output_labels[i] for i in list(outdx)])
1228+
self.dt, name=sysname, inputs=self.input_labels[inpdx], outputs=self.output_labels[outdx])
12291229

12301230
def sample(self, Ts, method='zoh', alpha=None, prewarp_frequency=None,
12311231
name=None, copy_names=True, **kwargs):

control/tests/statesp_test.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -463,24 +463,38 @@ def test_append_tf(self):
463463
np.testing.assert_array_almost_equal(sys3c.A[:3, 3:], np.zeros((3, 2)))
464464
np.testing.assert_array_almost_equal(sys3c.A[3:, :3], np.zeros((2, 3)))
465465

466-
def test_array_access_ss(self):
467-
466+
def test_array_access_ss_failure(self):
467+
sys1 = StateSpace(
468+
[[1., 2.], [3., 4.]],
469+
[[5., 6.], [6., 8.]],
470+
[[9., 10.], [11., 12.]],
471+
[[13., 14.], [15., 16.]], 1,
472+
inputs=['u0', 'u1'], outputs=['y0', 'y1'])
473+
with pytest.raises(IOError):
474+
sys1[0]
475+
476+
@pytest.mark.parametrize("outdx, inpdx",
477+
[(0, 1),
478+
(slice(0, 1, 1), 1),
479+
(0, slice(1, 2, 1)),
480+
(slice(0, 1, 1), slice(1, 2, 1))])
481+
def test_array_access_ss(self, outdx, inpdx):
468482
sys1 = StateSpace(
469483
[[1., 2.], [3., 4.]],
470484
[[5., 6.], [6., 8.]],
471485
[[9., 10.], [11., 12.]],
472486
[[13., 14.], [15., 16.]], 1,
473487
inputs=['u0', 'u1'], outputs=['y0', 'y1'])
474488

475-
sys1_01 = sys1[0, 1]
489+
sys1_01 = sys1[outdx, inpdx]
476490
np.testing.assert_array_almost_equal(sys1_01.A,
477491
sys1.A)
478492
np.testing.assert_array_almost_equal(sys1_01.B,
479493
sys1.B[:, 1:2])
480494
np.testing.assert_array_almost_equal(sys1_01.C,
481495
sys1.C[0:1, :])
482496
np.testing.assert_array_almost_equal(sys1_01.D,
483-
sys1.D[0, 1])
497+
sys1.D[0:1, 1:2])
484498

485499
assert sys1.dt == sys1_01.dt
486500
assert sys1_01.input_labels == ['u1']

control/xferfcn.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
4848
"""
4949

50+
from collections.abc import Iterable
51+
5052
# External function declarations
5153
import numpy as np
5254
from numpy import angle, array, empty, finfo, ndarray, ones, \
@@ -758,7 +760,12 @@ def __pow__(self, other):
758760
return (TransferFunction([1], [1]) / self) * (self**(other + 1))
759761

760762
def __getitem__(self, key):
763+
if not isinstance(key, Iterable) or len(key) != 2:
764+
raise IOError('must provide indices of length 2 for state space')
765+
761766
key1, key2 = key
767+
if not isinstance(key1, (int, slice)) or not isinstance(key2, (int, slice)):
768+
raise TypeError(f"system indices must be integers or slices")
762769

763770
# pre-process
764771
if isinstance(key1, int):

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy