Skip to content

Commit 47262f5

Browse files
committed
update coeff handling to allow multi-variable basis
1 parent 2901cbe commit 47262f5

File tree

2 files changed

+43
-43
lines changed

2 files changed

+43
-43
lines changed

control/optimal.py

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -268,17 +268,16 @@ def _cost_function(self, coeffs):
268268
start_time = time.process_time()
269269
logging.info("_cost_function called at: %g", start_time)
270270

271-
# Retrieve the initial state and reshape the input vector
271+
# Retrieve the saved initial state
272272
x = self.x
273-
coeffs = coeffs.reshape((self.system.ninputs, -1))
274273

275-
# Compute time points (if basis present)
274+
# Compute inputs
276275
if self.basis:
277276
if self.log:
278277
logging.debug("coefficients = " + str(coeffs))
279278
inputs = self._coeffs_to_inputs(coeffs)
280279
else:
281-
inputs = coeffs
280+
inputs = coeffs.reshape((self.system.ninputs, -1))
282281

283282
# See if we already have a simulation for this condition
284283
if np.array_equal(coeffs, self.last_coeffs) and \
@@ -391,15 +390,14 @@ def _constraint_function(self, coeffs):
391390
start_time = time.process_time()
392391
logging.info("_constraint_function called at: %g", start_time)
393392

394-
# Retrieve the initial state and reshape the input vector
393+
# Retrieve the initial state
395394
x = self.x
396-
coeffs = coeffs.reshape((self.system.ninputs, -1))
397395

398-
# Compute time points (if basis present)
396+
# Compute input at time points
399397
if self.basis:
400398
inputs = self._coeffs_to_inputs(coeffs)
401399
else:
402-
inputs = coeffs
400+
inputs = coeffs.reshape((self.system.ninputs, -1))
403401

404402
# See if we already have a simulation for this condition
405403
if np.array_equal(coeffs, self.last_coeffs) \
@@ -473,15 +471,14 @@ def _eqconst_function(self, coeffs):
473471
start_time = time.process_time()
474472
logging.info("_eqconst_function called at: %g", start_time)
475473

476-
# Retrieve the initial state and reshape the input vector
474+
# Retrieve the initial state
477475
x = self.x
478-
coeffs = coeffs.reshape((self.system.ninputs, -1))
479476

480-
# Compute time points (if basis present)
477+
# Compute input at time points
481478
if self.basis:
482479
inputs = self._coeffs_to_inputs(coeffs)
483480
else:
484-
inputs = coeffs
481+
inputs = coeffs.reshape((self.system.ninputs, -1))
485482

486483
# See if we already have a simulation for this condition
487484
if np.array_equal(coeffs, self.last_coeffs) and \
@@ -609,34 +606,36 @@ def _inputs_to_coeffs(self, inputs):
609606
return inputs
610607

611608
# Solve least squares problems (M x = b) for coeffs on each input
612-
coeffs = np.zeros((self.system.ninputs, self.basis.N))
609+
coeffs = []
613610
for i in range(self.system.ninputs):
614611
# Set up the matrices to get inputs
615-
M = np.zeros((self.timepts.size, self.basis.N))
612+
M = np.zeros((self.timepts.size, self.basis.var_ncoefs(i)))
616613
b = np.zeros(self.timepts.size)
617614

618615
# Evaluate at each time point and for each basis function
619616
# TODO: vectorize
620617
for j, t in enumerate(self.timepts):
621-
for k in range(self.basis.N):
618+
for k in range(self.basis.var_ncoefs(i)):
622619
M[j, k] = self.basis(k, t)
623-
b[j] = inputs[i, j]
620+
b[j] = inputs[i, j]
624621

625622
# Solve a least squares problem for the coefficients
626623
alpha, residuals, rank, s = np.linalg.lstsq(M, b, rcond=None)
627-
coeffs[i, :] = alpha
624+
coeffs.append(alpha)
628625

629-
return coeffs
626+
return np.hstack(coeffs)
630627

631628
# Utility function to convert coefficient vector to input vector
632629
def _coeffs_to_inputs(self, coeffs):
633630
# TODO: vectorize
634631
inputs = np.zeros((self.system.ninputs, self.timepts.size))
635-
for i, t in enumerate(self.timepts):
636-
for k in range(self.basis.N):
637-
phi_k = self.basis(k, t)
638-
for inp in range(self.system.ninputs):
639-
inputs[inp, i] += coeffs[inp, k] * phi_k
632+
offset = 0
633+
for i in range(self.system.ninputs):
634+
length = self.basis.var_ncoefs(i)
635+
for j, t in enumerate(self.timepts):
636+
for k in range(length):
637+
inputs[i, j] += coeffs[offset + k] * self.basis(k, t)
638+
offset += length
640639
return inputs
641640

642641
#
@@ -680,7 +679,7 @@ def _print_statistics(self, reset=True):
680679

681680
# Compute the optimal trajectory from the current state
682681
def compute_trajectory(
683-
self, x, squeeze=None, transpose=None, return_states=None,
682+
self, x, squeeze=None, transpose=None, return_states=True,
684683
initial_guess=None, print_summary=True, **kwargs):
685684
"""Compute the optimal input at state x
686685
@@ -689,8 +688,7 @@ def compute_trajectory(
689688
x : array-like or number, optional
690689
Initial state for the system.
691690
return_states : bool, optional
692-
If True, return the values of the state at each time (default =
693-
False).
691+
If True (default), return the values of the state at each time.
694692
squeeze : bool, optional
695693
If True and if the system has a single output, return the system
696694
output as a 1D array rather than a 2D array. If False, return the
@@ -837,7 +835,7 @@ class OptimalControlResult(sp.optimize.OptimizeResult):
837835
838836
"""
839837
def __init__(
840-
self, ocp, res, return_states=False, print_summary=False,
838+
self, ocp, res, return_states=True, print_summary=False,
841839
transpose=None, squeeze=None):
842840
"""Create a OptimalControlResult object"""
843841

@@ -848,14 +846,11 @@ def __init__(
848846
# Remember the optimal control problem that we solved
849847
self.problem = ocp
850848

851-
# Reshape and process the input vector
852-
coeffs = res.x.reshape((ocp.system.ninputs, -1))
853-
854-
# Compute time points (if basis present)
849+
# Compute input at time points
855850
if ocp.basis:
856-
inputs = ocp._coeffs_to_inputs(coeffs)
851+
inputs = ocp._coeffs_to_inputs(res.x)
857852
else:
858-
inputs = coeffs
853+
inputs = res.x.reshape((ocp.system.ninputs, -1))
859854

860855
# See if we got an answer
861856
if not res.success:
@@ -894,7 +889,7 @@ def __init__(
894889
def solve_ocp(
895890
sys, horizon, X0, cost, trajectory_constraints=None, terminal_cost=None,
896891
terminal_constraints=[], initial_guess=None, basis=None, squeeze=None,
897-
transpose=None, return_states=False, log=False, **kwargs):
892+
transpose=None, return_states=True, log=False, **kwargs):
898893

899894
"""Compute the solution to an optimal control problem
900895
@@ -949,7 +944,7 @@ def solve_ocp(
949944
If `True`, turn on logging messages (using Python logging module).
950945
951946
return_states : bool, optional
952-
If True, return the values of the state at each time (default = False).
947+
If True, return the values of the state at each time (default = True).
953948
954949
squeeze : bool, optional
955950
If True and if the system has a single output, return the system

control/tests/optimal_test.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,8 +300,9 @@ def test_terminal_constraints(sys_args):
300300
np.testing.assert_almost_equal(res.inputs, u1)
301301

302302
# Re-run using a basis function and see if we get the same answer
303-
res = opt.solve_ocp(sys, time, x0, cost, terminal_constraints=final_point,
304-
basis=flat.BezierFamily(8, Tf))
303+
res = opt.solve_ocp(
304+
sys, time, x0, cost, terminal_constraints=final_point,
305+
basis=flat.BezierFamily(8, Tf))
305306

306307
# Final point doesn't affect cost => don't need to test
307308
np.testing.assert_almost_equal(
@@ -471,8 +472,12 @@ def test_ocp_argument_errors():
471472
sys, time, x0, cost, terminal_constraints=constraints)
472473

473474

474-
def test_optimal_basis_simple():
475-
sys = ct.ss2io(ct.ss([[1, 1], [0, 1]], [[1], [0.5]], np.eye(2), 0, 1))
475+
@pytest.mark.parametrize("basis", [
476+
flat.PolyFamily(4), flat.PolyFamily(6),
477+
flat.BezierFamily(4), flat.BSplineFamily([0, 4, 8], 6)
478+
])
479+
def test_optimal_basis_simple(basis):
480+
sys = ct.ss([[1, 1], [0, 1]], [[1], [0.5]], np.eye(2), 0, 1)
476481

477482
# State and input constraints
478483
constraints = [
@@ -492,7 +497,7 @@ def test_optimal_basis_simple():
492497
# Basic optimal control problem
493498
res1 = opt.solve_ocp(
494499
sys, time, x0, cost, constraints,
495-
basis=flat.BezierFamily(4, Tf), return_x=True)
500+
terminal_cost=cost, basis=basis, return_x=True)
496501
assert res1.success
497502

498503
# Make sure the constraints were satisfied
@@ -503,14 +508,14 @@ def test_optimal_basis_simple():
503508
# Pass an initial guess and rerun
504509
res2 = opt.solve_ocp(
505510
sys, time, x0, cost, constraints, initial_guess=0.99*res1.inputs,
506-
basis=flat.BezierFamily(4, Tf), return_x=True)
511+
terminal_cost=cost, basis=basis, return_x=True)
507512
assert res2.success
508513
np.testing.assert_allclose(res2.inputs, res1.inputs, atol=0.01, rtol=0.01)
509514

510515
# Run with logging turned on for code coverage
511516
res3 = opt.solve_ocp(
512-
sys, time, x0, cost, constraints,
513-
basis=flat.BezierFamily(4, Tf), return_x=True, log=True)
517+
sys, time, x0, cost, constraints, terminal_cost=cost,
518+
basis=basis, return_x=True, log=True)
514519
assert res3.success
515520
np.testing.assert_almost_equal(res3.inputs, res1.inputs, decimal=3)
516521

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy