Skip to content

Commit 1a276dc

Browse files
authored
Merge pull request #1033 from murrayrm/ctrlplot_refactor-27Jun2024
Move ctrlplot code prior to upcoming PR
2 parents da64e0e + 009b821 commit 1a276dc

File tree

9 files changed

+293
-282
lines changed

9 files changed

+293
-282
lines changed

control/ctrlplot.py

Lines changed: 227 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,36 @@
55

66
from os.path import commonprefix
77

8+
import matplotlib as mpl
89
import matplotlib.pyplot as plt
910
import numpy as np
1011

1112
from . import config
1213

1314
__all__ = ['suptitle', 'get_plot_axes']
1415

16+
#
17+
# Style parameters
18+
#
19+
20+
_ctrlplot_rcParams = mpl.rcParams.copy()
21+
_ctrlplot_rcParams.update({
22+
'axes.labelsize': 'small',
23+
'axes.titlesize': 'small',
24+
'figure.titlesize': 'medium',
25+
'legend.fontsize': 'x-small',
26+
'xtick.labelsize': 'small',
27+
'ytick.labelsize': 'small',
28+
})
29+
30+
31+
#
32+
# User functions
33+
#
34+
# The functions below can be used by users to modify ctrl plots or get
35+
# information about them.
36+
#
37+
1538

1639
def suptitle(
1740
title, fig=None, frame='axes', **kwargs):
@@ -35,7 +58,7 @@ def suptitle(
3558
Additional keywords (passed to matplotlib).
3659
3760
"""
38-
rcParams = config._get_param('freqplot', 'rcParams', kwargs, pop=True)
61+
rcParams = config._get_param('ctrlplot', 'rcParams', kwargs, pop=True)
3962

4063
if fig is None:
4164
fig = plt.gcf()
@@ -61,10 +84,10 @@ def suptitle(
6184
def get_plot_axes(line_array):
6285
"""Get a list of axes from an array of lines.
6386
64-
This function can be used to return the set of axes corresponding to
65-
the line array that is returned by `time_response_plot`. This is useful for
66-
generating an axes array that can be passed to subsequent plotting
67-
calls.
87+
This function can be used to return the set of axes corresponding
88+
to the line array that is returned by `time_response_plot`. This
89+
is useful for generating an axes array that can be passed to
90+
subsequent plotting calls.
6891
6992
Parameters
7093
----------
@@ -89,6 +112,125 @@ def get_plot_axes(line_array):
89112
#
90113
# Utility functions
91114
#
115+
# These functions are used by plotting routines to provide a consistent way
116+
# of processing and displaying information.
117+
#
118+
119+
120+
def _process_ax_keyword(
121+
axs, shape=(1, 1), rcParams=None, squeeze=False, clear_text=False):
122+
"""Utility function to process ax keyword to plotting commands.
123+
124+
This function processes the `ax` keyword to plotting commands. If no
125+
ax keyword is passed, the current figure is checked to see if it has
126+
the correct shape. If the shape matches the desired shape, then the
127+
current figure and axes are returned. Otherwise a new figure is
128+
created with axes of the desired shape.
129+
130+
Legacy behavior: some of the older plotting commands use a axes label
131+
to identify the proper axes for plotting. This behavior is supported
132+
through the use of the label keyword, but will only work if shape ==
133+
(1, 1) and squeeze == True.
134+
135+
"""
136+
if axs is None:
137+
fig = plt.gcf() # get current figure (or create new one)
138+
axs = fig.get_axes()
139+
140+
# Check to see if axes are the right shape; if not, create new figure
141+
# Note: can't actually check the shape, just the total number of axes
142+
if len(axs) != np.prod(shape):
143+
with plt.rc_context(rcParams):
144+
if len(axs) != 0:
145+
# Create a new figure
146+
fig, axs = plt.subplots(*shape, squeeze=False)
147+
else:
148+
# Create new axes on (empty) figure
149+
axs = fig.subplots(*shape, squeeze=False)
150+
fig.set_layout_engine('tight')
151+
fig.align_labels()
152+
else:
153+
# Use the existing axes, properly reshaped
154+
axs = np.asarray(axs).reshape(*shape)
155+
156+
if clear_text:
157+
# Clear out any old text from the current figure
158+
for text in fig.texts:
159+
text.set_visible(False) # turn off the text
160+
del text # get rid of it completely
161+
else:
162+
try:
163+
axs = np.asarray(axs).reshape(shape)
164+
except ValueError:
165+
raise ValueError(
166+
"specified axes are not the right shape; "
167+
f"got {axs.shape} but expecting {shape}")
168+
fig = axs[0, 0].figure
169+
170+
# Process the squeeze keyword
171+
if squeeze and shape == (1, 1):
172+
axs = axs[0, 0] # Just return the single axes object
173+
elif squeeze:
174+
axs = axs.squeeze()
175+
176+
return fig, axs
177+
178+
179+
# Turn label keyword into array indexed by trace, output, input
180+
# TODO: move to ctrlutil.py and update parameter names to reflect general use
181+
def _process_line_labels(label, ntraces, ninputs=0, noutputs=0):
182+
if label is None:
183+
return None
184+
185+
if isinstance(label, str):
186+
label = [label] * ntraces # single label for all traces
187+
188+
# Convert to an ndarray, if not done aleady
189+
try:
190+
line_labels = np.asarray(label)
191+
except ValueError:
192+
raise ValueError("label must be a string or array_like")
193+
194+
# Turn the data into a 3D array of appropriate shape
195+
# TODO: allow more sophisticated broadcasting (and error checking)
196+
try:
197+
if ninputs > 0 and noutputs > 0:
198+
if line_labels.ndim == 1 and line_labels.size == ntraces:
199+
line_labels = line_labels.reshape(ntraces, 1, 1)
200+
line_labels = np.broadcast_to(
201+
line_labels, (ntraces, ninputs, noutputs))
202+
else:
203+
line_labels = line_labels.reshape(ntraces, ninputs, noutputs)
204+
except ValueError:
205+
if line_labels.shape[0] != ntraces:
206+
raise ValueError("number of labels must match number of traces")
207+
else:
208+
raise ValueError("labels must be given for each input/output pair")
209+
210+
return line_labels
211+
212+
213+
# Get labels for all lines in an axes
214+
def _get_line_labels(ax, use_color=True):
215+
labels, lines = [], []
216+
last_color, counter = None, 0 # label unknown systems
217+
for i, line in enumerate(ax.get_lines()):
218+
label = line.get_label()
219+
if use_color and label.startswith("Unknown"):
220+
label = f"Unknown-{counter}"
221+
if last_color is None:
222+
last_color = line.get_color()
223+
elif last_color != line.get_color():
224+
counter += 1
225+
last_color = line.get_color()
226+
elif label[0] == '_':
227+
continue
228+
229+
if label not in labels:
230+
lines.append(line)
231+
labels.append(label)
232+
233+
return lines, labels
92234

93235

94236
# Utility function to make legend labels
@@ -160,3 +302,83 @@ def _find_axes_center(fig, axs):
160302
ylim = [min(ll[1], ylim[0]), max(ur[1], ylim[1])]
161303

162304
return (np.sum(xlim)/2, np.sum(ylim)/2)
305+
306+
307+
# Internal function to add arrows to a curve
308+
def _add_arrows_to_line2D(
309+
axes, line, arrow_locs=[0.2, 0.4, 0.6, 0.8],
310+
arrowstyle='-|>', arrowsize=1, dir=1):
311+
"""
312+
Add arrows to a matplotlib.lines.Line2D at selected locations.
313+
314+
Parameters:
315+
-----------
316+
axes: Axes object as returned by axes command (or gca)
317+
line: Line2D object as returned by plot command
318+
arrow_locs: list of locations where to insert arrows, % of total length
319+
arrowstyle: style of the arrow
320+
arrowsize: size of the arrow
321+
322+
Returns:
323+
--------
324+
arrows: list of arrows
325+
326+
Based on https://stackoverflow.com/questions/26911898/
327+
328+
"""
329+
# Get the coordinates of the line, in plot coordinates
330+
if not isinstance(line, mpl.lines.Line2D):
331+
raise ValueError("expected a matplotlib.lines.Line2D object")
332+
x, y = line.get_xdata(), line.get_ydata()
333+
334+
# Determine the arrow properties
335+
arrow_kw = {"arrowstyle": arrowstyle}
336+
337+
color = line.get_color()
338+
use_multicolor_lines = isinstance(color, np.ndarray)
339+
if use_multicolor_lines:
340+
raise NotImplementedError("multicolor lines not supported")
341+
else:
342+
arrow_kw['color'] = color
343+
344+
linewidth = line.get_linewidth()
345+
if isinstance(linewidth, np.ndarray):
346+
raise NotImplementedError("multiwidth lines not supported")
347+
else:
348+
arrow_kw['linewidth'] = linewidth
349+
350+
# Figure out the size of the axes (length of diagonal)
351+
xlim, ylim = axes.get_xlim(), axes.get_ylim()
352+
ul, lr = np.array([xlim[0], ylim[0]]), np.array([xlim[1], ylim[1]])
353+
diag = np.linalg.norm(ul - lr)
354+
355+
# Compute the arc length along the curve
356+
s = np.cumsum(np.sqrt(np.diff(x) ** 2 + np.diff(y) ** 2))
357+
358+
# Truncate the number of arrows if the curve is short
359+
# TODO: figure out a smarter way to do this
360+
frac = min(s[-1] / diag, 1)
361+
if len(arrow_locs) and frac < 0.05:
362+
arrow_locs = [] # too short; no arrows at all
363+
elif len(arrow_locs) and frac < 0.2:
364+
arrow_locs = [0.5] # single arrow in the middle
365+
366+
# Plot the arrows (and return list if patches)
367+
arrows = []
368+
for loc in arrow_locs:
369+
n = np.searchsorted(s, s[-1] * loc)
370+
371+
if dir == 1 and n == 0:
372+
# Move the arrow forward by one if it is at start of a segment
373+
n = 1
374+
375+
# Place the head of the arrow at the desired location
376+
arrow_head = [x[n], y[n]]
377+
arrow_tail = [x[n - dir], y[n - dir]]
378+
379+
p = mpl.patches.FancyArrowPatch(
380+
arrow_tail, arrow_head, transform=axes.transData, lw=0,
381+
**arrow_kw)
382+
axes.add_patch(p)
383+
arrows.append(p)
384+
return arrows

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy