Skip to content

Commit c6f0281

Browse files
committed
ENH: define matvec and vecmat gufuncs
1 parent 877ca75 commit c6f0281

File tree

8 files changed

+618
-48
lines changed

8 files changed

+618
-48
lines changed

numpy/__init__.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@
150150
isnan, isnat, isscalar, issubdtype, lcm, ldexp, left_shift, less,
151151
less_equal, lexsort, linspace, little_endian, log, log10, log1p, log2,
152152
logaddexp, logaddexp2, logical_and, logical_not, logical_or,
153-
logical_xor, logspace, long, longdouble, longlong, matmul,
153+
logical_xor, logspace, long, longdouble, longlong, matmul, matvec,
154154
matrix_transpose, max, maximum, may_share_memory, mean, memmap, min,
155155
min_scalar_type, minimum, mod, modf, moveaxis, multiply, nan, ndarray,
156156
ndim, nditer, negative, nested_iters, newaxis, nextafter, nonzero,
@@ -165,11 +165,11 @@
165165
str_, subtract, sum, swapaxes, take, tan, tanh, tensordot,
166166
timedelta64, trace, transpose, true_divide, trunc, typecodes, ubyte,
167167
ufunc, uint, uint16, uint32, uint64, uint8, uintc, uintp, ulong,
168-
ulonglong, unsignedinteger, ushort, var, vdot, vecdot, void, vstack,
169-
where, zeros, zeros_like
168+
ulonglong, unsignedinteger, ushort, var, vdot, vecdot, vecmat, void,
169+
vstack, where, zeros, zeros_like
170170
)
171171

172-
# NOTE: It's still under discussion whether these aliases
172+
# NOTE: It's still under discussion whether these aliases
173173
# should be removed.
174174
for ta in ["float96", "float128", "complex192", "complex256"]:
175175
try:
@@ -184,20 +184,20 @@
184184
histogram, histogram_bin_edges, histogramdd
185185
)
186186
from .lib._nanfunctions_impl import (
187-
nanargmax, nanargmin, nancumprod, nancumsum, nanmax, nanmean,
187+
nanargmax, nanargmin, nancumprod, nancumsum, nanmax, nanmean,
188188
nanmedian, nanmin, nanpercentile, nanprod, nanquantile, nanstd,
189189
nansum, nanvar
190190
)
191191
from .lib._function_base_impl import (
192-
select, piecewise, trim_zeros, copy, iterable, percentile, diff,
192+
select, piecewise, trim_zeros, copy, iterable, percentile, diff,
193193
gradient, angle, unwrap, sort_complex, flip, rot90, extract, place,
194194
vectorize, asarray_chkfinite, average, bincount, digitize, cov,
195195
corrcoef, median, sinc, hamming, hanning, bartlett, blackman,
196196
kaiser, trapz, i0, meshgrid, delete, insert, append, interp, quantile
197197
)
198198
from .lib._twodim_base_impl import (
199-
diag, diagflat, eye, fliplr, flipud, tri, triu, tril, vander,
200-
histogram2d, mask_indices, tril_indices, tril_indices_from,
199+
diag, diagflat, eye, fliplr, flipud, tri, triu, tril, vander,
200+
histogram2d, mask_indices, tril_indices, tril_indices_from,
201201
triu_indices, triu_indices_from
202202
)
203203
from .lib._shape_base_impl import (
@@ -206,7 +206,7 @@
206206
take_along_axis, tile, vsplit
207207
)
208208
from .lib._type_check_impl import (
209-
iscomplexobj, isrealobj, imag, iscomplex, isreal, nan_to_num, real,
209+
iscomplexobj, isrealobj, imag, iscomplex, isreal, nan_to_num, real,
210210
real_if_close, typename, mintypecode, common_type
211211
)
212212
from .lib._arraysetops_impl import (
@@ -231,7 +231,7 @@
231231
)
232232
from .lib._index_tricks_impl import (
233233
diag_indices_from, diag_indices, fill_diagonal, ndindex, ndenumerate,
234-
ix_, c_, r_, s_, ogrid, mgrid, unravel_index, ravel_multi_index,
234+
ix_, c_, r_, s_, ogrid, mgrid, unravel_index, ravel_multi_index,
235235
index_exp
236236
)
237237
from . import matrixlib as _mat
@@ -244,7 +244,7 @@
244244
# (experimental label) are not added here, because `from numpy import *`
245245
# must not raise any warnings - that's too disruptive.
246246
__numpy_submodules__ = {
247-
"linalg", "fft", "dtypes", "random", "polynomial", "ma",
247+
"linalg", "fft", "dtypes", "random", "polynomial", "ma",
248248
"exceptions", "lib", "ctypeslib", "testing", "typing",
249249
"f2py", "test", "rec", "char", "core", "strings",
250250
}
@@ -391,7 +391,7 @@ def __getattr__(attr):
391391

392392
if attr in __former_attrs__:
393393
raise AttributeError(__former_attrs__[attr])
394-
394+
395395
if attr in __expired_attributes__:
396396
raise AttributeError(
397397
f"`np.{attr}` was removed in the NumPy 2.0 release. "
@@ -414,7 +414,7 @@ def __dir__():
414414
globals().keys() | __numpy_submodules__
415415
)
416416
public_symbols -= {
417-
"matrixlib", "matlib", "tests", "conftest", "version",
417+
"matrixlib", "matlib", "tests", "conftest", "version",
418418
"compat", "distutils", "array_api"
419419
}
420420
return list(public_symbols)
@@ -488,7 +488,7 @@ def _mac_os_check():
488488
def hugepage_setup():
489489
"""
490490
We usually use madvise hugepages support, but on some old kernels it
491-
is slow and thus better avoided. Specifically kernel version 4.6
491+
is slow and thus better avoided. Specifically kernel version 4.6
492492
had a bug fix which probably fixed this:
493493
https://github.com/torvalds/linux/commit/7cf91a98e607c2f935dbcc177d70011e95b8faff
494494
"""
@@ -497,7 +497,7 @@ def hugepage_setup():
497497
# If there is an issue with parsing the kernel version,
498498
# set use_hugepage to 0. Usage of LooseVersion will handle
499499
# the kernel version parsing better, but avoided since it
500-
# will increase the import time.
500+
# will increase the import time.
501501
# See: #16679 for related discussion.
502502
try:
503503
use_hugepage = 1

numpy/__init__.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3308,6 +3308,7 @@ logical_not: _UFunc_Nin1_Nout1[L['logical_not'], L[20], None]
33083308
logical_or: _UFunc_Nin2_Nout1[L['logical_or'], L[20], L[False]]
33093309
logical_xor: _UFunc_Nin2_Nout1[L['logical_xor'], L[19], L[False]]
33103310
matmul: _GUFunc_Nin2_Nout1[L['matmul'], L[19], None]
3311+
matvec: _GUFunc_Nin2_Nout1[L['matvec'], L[19], None]
33113312
maximum: _UFunc_Nin2_Nout1[L['maximum'], L[21], None]
33123313
minimum: _UFunc_Nin2_Nout1[L['minimum'], L[21], None]
33133314
mod: _UFunc_Nin2_Nout1[L['remainder'], L[16], None]
@@ -3337,6 +3338,7 @@ tanh: _UFunc_Nin1_Nout1[L['tanh'], L[8], None]
33373338
true_divide: _UFunc_Nin2_Nout1[L['true_divide'], L[11], None]
33383339
trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None]
33393340
vecdot: _GUFunc_Nin2_Nout1[L['vecdot'], L[19], None]
3341+
vecmat: _GUFunc_Nin2_Nout1[L['vecmat'], L[19], None]
33403342

33413343
abs = absolute
33423344
acos = arccos

numpy/_core/code_generators/generate_umath.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,22 @@ def english_upper(s):
11501150
TD(O),
11511151
signature='(n),(n)->()',
11521152
),
1153+
'matvec':
1154+
Ufunc(2, 1, None,
1155+
docstrings.get('numpy._core.umath.matvec'),
1156+
"PyUFunc_SimpleUniformOperationTypeResolver",
1157+
TD(notimes_or_obj),
1158+
TD(O),
1159+
signature='(m,n),(n)->(m)',
1160+
),
1161+
'vecmat':
1162+
Ufunc(2, 1, None,
1163+
docstrings.get('numpy._core.umath.vecmat'),
1164+
"PyUFunc_SimpleUniformOperationTypeResolver",
1165+
TD(notimes_or_obj),
1166+
TD(O),
1167+
signature='(n),(n,m)->(m)',
1168+
),
11531169
'str_len':
11541170
Ufunc(1, 1, Zero,
11551171
docstrings.get('numpy._core.umath.str_len'),

numpy/_core/code_generators/ufunc_docstrings.py

Lines changed: 116 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def add_newdoc(place, name, doc):
4444

4545
skip = (
4646
# gufuncs do not use the OUT_SCALAR replacement strings
47-
'matmul', 'vecdot',
47+
'matmul', 'vecdot', 'matvec', 'vecmat',
4848
# clip has 3 inputs, which is not handled by this
4949
'clip',
5050
)
@@ -2880,8 +2880,8 @@ def add_newdoc(place, name, doc):
28802880
Input arrays, scalars not allowed.
28812881
out : ndarray, optional
28822882
A location into which the result is stored. If provided, it must have
2883-
a shape that the broadcasted shape of `x1` and `x2` with the last axis
2884-
removed. If not provided or None, a freshly-allocated array is used.
2883+
the broadcasted shape of `x1` and `x2` with the last axis removed.
2884+
If not provided or None, a freshly-allocated array is used.
28852885
**kwargs
28862886
For other keyword-only arguments, see the
28872887
:ref:`ufunc docs <ufuncs.kwargs>`.
@@ -2917,6 +2917,119 @@ def add_newdoc(place, name, doc):
29172917
.. versionadded:: 2.0.0
29182918
""")
29192919

2920+
add_newdoc('numpy._core.umath', 'matvec',
2921+
"""
2922+
Matrix-vector dot product of two arrays.
2923+
2924+
Let :math:`\\mathbf{A}` be a maxtrix in ``x1`` and :math:`\\mathbf{b}` be
2925+
a vector in ``x2``. The matrix-vector product is defined as:
2926+
2927+
.. math::
2928+
\\mathbf{A} \\cdot \\mathbf{b} = \\sum_{j=0}^{n-1} A_{ij} b_j
2929+
2930+
where the sum is over the last dimensions in ``x1`` and ``x2``
2931+
(unless ``axes`` is specified).
2932+
2933+
Parameters
2934+
----------
2935+
x1, x2 : array_like
2936+
Input arrays, scalars not allowed.
2937+
out : ndarray, optional
2938+
A location into which the result is stored. If provided, it must have
2939+
the broadcasted shape of ``x1`` and ``x2`` with the summation axis
2940+
removed. If not provided or None, a freshly-allocated array is used.
2941+
**kwargs
2942+
For other keyword-only arguments, see the
2943+
:ref:`ufunc docs <ufuncs.kwargs>`.
2944+
2945+
Returns
2946+
-------
2947+
y : ndarray
2948+
The matrix-vector product of the inputs.
2949+
2950+
Raises
2951+
------
2952+
ValueError
2953+
If the last dimensions of ``x1`` and ``x2`` are not the same size.
2954+
2955+
If a scalar value is passed in.
2956+
2957+
See Also
2958+
--------
2959+
vecmat : Vector-matrix product.
2960+
einsum : Einstein summation convention.
2961+
2962+
Examples
2963+
--------
2964+
Project a matrix along a given direction.
2965+
2966+
>>> a = np.array([[0., 5., 0.], [0., 0., 10.], [0., 6., 8.]])
2967+
>>> n = np.array([0., 0.6, 0.8])
2968+
>>> np.matvec(a, n)
2969+
array([ 3., 8., 10.])
2970+
2971+
.. versionadded:: 2.0.0
2972+
""")
2973+
2974+
add_newdoc('numpy._core.umath', 'vecmat',
2975+
"""
2976+
Vector-matrix dot product of two arrays.
2977+
2978+
Let :math:`\\mathbf{b}` be a vector in ``x1`` and :math:`\\mathbf{A}` be
2979+
a matrix in ``x2``. The vector-matrix product is defined as:
2980+
2981+
.. math::
2982+
\\mathbf{b} \\cdot \\mathbf{A} = \\sum_{i=0}^{n-1} \\overline{b_i}A_{ij}
2983+
2984+
where the sum is over the last dimension of ``x1`` and the one-but-last
2985+
dimensions in ``x2`` (unless `axes` is specified) and where
2986+
:math:`\\overline{b_i}` denotes the complex conjugate if :math:`b`
2987+
is complex and the identity otherwise.
2988+
2989+
Parameters
2990+
----------
2991+
x1, x2 : array_like
2992+
Input arrays, scalars not allowed.
2993+
out : ndarray, optional
2994+
A location into which the result is stored. If provided, it must have
2995+
the broadcasted shape of ``x1`` and ``x2`` with the summation axis
2996+
removed. If not provided or None, a freshly-allocated array is used.
2997+
**kwargs
2998+
For other keyword-only arguments, see the
2999+
:ref:`ufunc docs <ufuncs.kwargs>`.
3000+
3001+
Returns
3002+
-------
3003+
y : ndarray
3004+
The vector-matrix product of the inputs.
3005+
3006+
Raises
3007+
------
3008+
ValueError
3009+
If the last dimensions of ``x1`` and the one-but-last dimension of
3010+
``x2`` are not the same size.
3011+
3012+
If a scalar value is passed in.
3013+
3014+
See Also
3015+
--------
3016+
matvec : Matrix-vector product.
3017+
einsum : Einstein summation convention.
3018+
3019+
Examples
3020+
--------
3021+
Project a matrix along a given direction.
3022+
3023+
Project a vector along X and Y.
3024+
3025+
>>> n = np.array([0., 4., 2.])
3026+
>>> m = np.array([[1., 0., 0.], [0., 1., 0.], [0., 0., 0.]])
3027+
>>> np.vecmat(n, m)
3028+
array([ 0., 4., 0.])
3029+
3030+
.. versionadded:: 2.0.0
3031+
""")
3032+
29203033
add_newdoc('numpy._core.umath', 'modf',
29213034
"""
29223035
Return the fractional and integral parts of an array, element-wise.

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy