diff --git a/doc/specs/stdlib_intrinsics.md b/doc/specs/stdlib_intrinsics.md index be23adeb5..971a55ecd 100644 --- a/doc/specs/stdlib_intrinsics.md +++ b/doc/specs/stdlib_intrinsics.md @@ -155,4 +155,42 @@ The output is a scalar of the same type and kind as to that of `x` and `y`. ```fortran {!example/intrinsics/example_dot_product.f90!} -``` \ No newline at end of file +``` + +### `stdlib_matmul` function + +#### Description + +The extension of the intrinsic function `matmul` to multiply a chain of 2 to 5 matrices in a single call, with error handling using `linalg_state_type`. +The optimal parenthesization, i.e. the one that minimizes the number of scalar multiplications, is determined using the matrix-chain ordering algorithm outlined in Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2. +The actual matrix multiplication is performed using the `gemm` interfaces. +It supports only `real` and `complex` matrices. + +#### Syntax + +`res = ` [[stdlib_intrinsics(module):stdlib_matmul(interface)]] ` (m1, m2, m3, m4, m5, err)` + +#### Status + +Experimental + +#### Class + +Function. + +#### Argument(s) + +`m1`, `m2`: 2D arrays of the same kind and type. `intent(in)` arguments. +`m3`,`m4`,`m5`: 2D arrays of the same kind and type as the other matrices. `intent(in), optional` arguments. +`err`: `type(linalg_state_type), intent(out), optional` argument. Can be used for elegant error handling. It is assigned `LINALG_VALUE_ERROR` + in case the matrices are not of compatible sizes. + +#### Result + +The output is a matrix with as many rows as `m1` and as many columns as the last matrix supplied. If the sizes are incompatible and `err` is supplied, a zero-sized matrix is returned. 
+
+#### Example
+
+```fortran
+{!example/intrinsics/example_matmul.f90!}
+```
diff --git a/example/intrinsics/CMakeLists.txt b/example/intrinsics/CMakeLists.txt
index 1645ba8a1..162744b66 100644
--- a/example/intrinsics/CMakeLists.txt
+++ b/example/intrinsics/CMakeLists.txt
@@ -1,2 +1,3 @@
 ADD_EXAMPLE(sum)
-ADD_EXAMPLE(dot_product)
\ No newline at end of file
+ADD_EXAMPLE(dot_product)
+ADD_EXAMPLE(matmul)
diff --git a/example/intrinsics/example_matmul.f90 b/example/intrinsics/example_matmul.f90
new file mode 100644
index 000000000..b62215074
--- /dev/null
+++ b/example/intrinsics/example_matmul.f90
@@ -0,0 +1,14 @@
+program example_matmul
+    ! Multiplies chains of two to five small complex matrices, one call per chain.
+    use stdlib_intrinsics, only: stdlib_matmul
+    implicit none
+    complex :: x(2, 2), y(2, 2), z(2, 2)
+    x = reshape([(0, 0), (1, 0), (1, 0), (0, 0)], [2, 2]) ! pauli x-matrix
+    y = reshape([(0, 0), (0, 1), (0, -1), (0, 0)], [2, 2]) ! pauli y-matrix
+    z = reshape([(1, 0), (0, 0), (0, 0), (-1, 0)], [2, 2]) ! pauli z-matrix
+
+    print *, stdlib_matmul(x, y) ! should be iota*z
+    print *, stdlib_matmul(y, z, x) ! should be iota*identity
+    print *, stdlib_matmul(x, x, z, y) ! should be -iota*x
+    print *, stdlib_matmul(x, x, z, y, y) ! 
should be z
+end program example_matmul
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index c3cd99120..5e915dfed 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -19,6 +19,7 @@ set(fppFiles
     stdlib_hash_64bit_spookyv2.fypp
     stdlib_intrinsics_dot_product.fypp
     stdlib_intrinsics_sum.fypp
+    stdlib_intrinsics_matmul.fypp
     stdlib_intrinsics.fypp
     stdlib_io.fypp
     stdlib_io_npy.fypp
@@ -32,14 +33,14 @@ set(fppFiles
     stdlib_linalg_kronecker.fypp
     stdlib_linalg_cross_product.fypp
     stdlib_linalg_eigenvalues.fypp
-    stdlib_linalg_solve.fypp
+    stdlib_linalg_solve.fypp
     stdlib_linalg_determinant.fypp
     stdlib_linalg_qr.fypp
     stdlib_linalg_inverse.fypp
     stdlib_linalg_pinv.fypp
     stdlib_linalg_norms.fypp
     stdlib_linalg_state.fypp
-    stdlib_linalg_svd.fypp
+    stdlib_linalg_svd.fypp
     stdlib_linalg_cholesky.fypp
     stdlib_linalg_schur.fypp
     stdlib_optval.fypp
diff --git a/src/stdlib_intrinsics.fypp b/src/stdlib_intrinsics.fypp
index b2c16a5a6..4b5751df6 100644
--- a/src/stdlib_intrinsics.fypp
+++ b/src/stdlib_intrinsics.fypp
@@ -8,6 +8,7 @@ module stdlib_intrinsics
     !!Alternative implementations of some Fortran intrinsic functions offering either faster and/or more accurate evaluation.
     !! ([Specification](../page/specs/stdlib_intrinsics.html))
     use stdlib_kinds
+    use stdlib_linalg_state, only: linalg_state_type
     implicit none
     private
 
@@ -146,6 +147,49 @@ module stdlib_intrinsics
         #:endfor
     end interface
     public :: kahan_kernel
+
+    interface stdlib_matmul
+        !! version: experimental
+        !!
+        !!### Summary
+        !! compute the product of a chain of two to five matrices with a single function call.
+        !! ([Specification](../page/specs/stdlib_intrinsics.html#stdlib_matmul))
+        !!
+        !!### Description
+        !!
+        !! multiply a chain of two to five matrices with a single function call;
+        !! the parenthesization that is optimal for efficiency of computation is chosen automatically.
+        !! Supported data types are `real` and `complex`.
+        !!
+        !! 
Note: The matrices must be of compatible shapes to be multiplied
+        #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
+            ! Variant without an error argument; declared pure.
+            pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
+                ${t}$, intent(in) :: m1(:,:), m2(:,:)
+                ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
+                ${t}$, allocatable :: r(:,:)
+            end function stdlib_matmul_pure_${s}$
+
+            ! Variant with a mandatory `err` argument that receives the error state.
+            module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r)
+                ${t}$, intent(in) :: m1(:,:), m2(:,:)
+                ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
+                type(linalg_state_type), intent(out) :: err
+                ${t}$, allocatable :: r(:,:)
+            end function stdlib_matmul_${s}$
+        #:endfor
+    end interface stdlib_matmul
+    public :: stdlib_matmul
+
+    ! internal interface
+    ! (subroutine form with an optional `err`; the module is `private` by default
+    !  and this name is not listed as public)
+    interface stdlib_matmul_sub
+        #:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
+            pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err)
+                ${t}$, intent(out), allocatable :: res(:,:)
+                ${t}$, intent(in) :: m1(:,:), m2(:,:)
+                ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
+                type(linalg_state_type), intent(out), optional :: err
+            end subroutine stdlib_matmul_sub_${s}$
+        #:endfor
+    end interface stdlib_matmul_sub
 
 contains
 
diff --git a/src/stdlib_intrinsics_matmul.fypp b/src/stdlib_intrinsics_matmul.fypp
new file mode 100644
index 000000000..c24f95374
--- /dev/null
+++ b/src/stdlib_intrinsics_matmul.fypp
@@ -0,0 +1,288 @@
+#:include "common.fypp"
+#:set I_KINDS_TYPES = list(zip(INT_KINDS, INT_TYPES, INT_KINDS))
+#:set R_KINDS_TYPES = list(zip(REAL_KINDS, REAL_TYPES, REAL_SUFFIX))
+#:set C_KINDS_TYPES = list(zip(CMPLX_KINDS, CMPLX_TYPES, CMPLX_SUFFIX))
+
+submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
+    use stdlib_linalg_blas, only: gemm
+    use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR, LINALG_INTERNAL_ERROR
+    use stdlib_constants
+    implicit none
+
+    ! procedure name used as the origin of every error raised in this submodule
+    character(len=*), parameter :: this = "stdlib_matmul"
+
+contains
+
+    ! 
Algorithm for the optimal parenthesization of matrices
+    ! Reference: Cormen, "Introduction to Algorithms", 4ed, ch-14, section-2
+    ! Internal use only!
+    !
+    ! p holds the chain dimensions: matrix i of the chain is p(i) x p(i+1).
+    ! On return s(i,j) is the index k at which the sub-chain i..j is split into
+    ! (i..k)*(k+1..j) for the fewest scalar multiplications; entries that are
+    ! never visited stay 0.
+    pure function matmul_chain_order(p) result(s)
+        integer, intent(in) :: p(:)
+        ! costs are accumulated in a wide integer kind: the product of three
+        ! matrix dimensions easily exceeds the range of a default integer
+        integer, parameter :: wide = selected_int_kind(18)
+        integer :: s(1:size(p) - 2, 2:size(p) - 1)
+        integer(wide) :: m(1:size(p) - 1, 1:size(p) - 1), q
+        integer :: n, l, i, j, k
+        n = size(p) - 1
+        m(:,:) = 0
+        s(:,:) = 0
+
+        ! l is the length of the sub-chain i..j being optimized
+        do l = 2, n
+            do i = 1, n - l + 1
+                j = i + l - 1
+                m(i,j) = huge(1_wide)
+
+                do k = i, j - 1
+                    q = m(i,k) + m(k+1,j) + int(p(i), wide)*p(k+1)*p(j+1)
+
+                    if (q < m(i, j)) then
+                        m(i,j) = q
+                        s(i,j) = k
+                    end if
+                end do
+            end do
+        end do
+    end function matmul_chain_order
+
+#:for k, t, s in R_KINDS_TYPES + C_KINDS_TYPES
+
+    ! Product m1*m2*m3 of three consecutive chain matrices, where m1 is matrix
+    ! number `start` of the chain. s is the split table produced by
+    ! matmul_chain_order and p the chain dimensions (matrix i is p(i) x p(i+1));
+    ! s(start, start + 2) selects which of the two orderings is evaluated.
+    pure function matmul_chain_mult_${s}$_3 (m1, m2, m3, start, s, p) result(r)
+        ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:)
+        integer, intent(in) :: start, s(:,2:), p(:)
+        ${t}$, allocatable :: r(:,:), temp(:,:)
+        integer :: ord, m, n, k
+        ord = s(start, start + 2)
+        allocate(r(p(start), p(start + 3)))
+
+        if (ord == start) then
+            ! m1*(m2*m3)
+            m = p(start + 1)
+            n = p(start + 3)
+            k = p(start + 2)
+            allocate(temp(m,n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m2, m, m3, k, zero_${s}$, temp, m)
+            m = p(start)
+            n = p(start + 3)
+            k = p(start + 1)
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
+        else if (ord == start + 1) then
+            ! (m1*m2)*m3
+            m = p(start)
+            n = p(start + 2)
+            k = p(start + 1)
+            allocate(temp(m, n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
+            m = p(start)
+            n = p(start + 3)
+            ! inner dimension: columns of (m1*m2) = rows of m3
+            k = p(start + 2)
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m3, k, zero_${s}$, r, m)
+        else
+            ! 
our internal functions are incorrect, abort
+            error stop this//": error: unexpected s(i,j)"
+        end if
+
+    end function matmul_chain_mult_${s}$_3
+
+    ! Product m1*m2*m3*m4 of four consecutive chain matrices, where m1 is matrix
+    ! number `start` of the chain. s(start, start + 3) selects the outermost
+    ! split; three-matrix sub-products are delegated to matmul_chain_mult_${s}$_3.
+    ! p(start:start + 4) are read as the dimensions of the four matrices.
+    pure function matmul_chain_mult_${s}$_4 (m1, m2, m3, m4, start, s, p) result(r)
+        ${t}$, intent(in) :: m1(:,:), m2(:,:), m3(:,:), m4(:,:)
+        integer, intent(in) :: start, s(:,2:), p(:)
+        ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
+        integer :: ord, m, n, k
+        ord = s(start, start + 3)
+        allocate(r(p(start), p(start + 4)))
+
+        if (ord == start) then
+            ! m1*(m2*m3*m4)
+            ! the sub-chain begins one matrix later, hence start + 1
+            temp = matmul_chain_mult_${s}$_3(m2, m3, m4, start + 1, s, p)
+            m = p(start)
+            n = p(start + 4)
+            k = p(start + 1)
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r, m)
+        else if (ord == start + 1) then
+            ! (m1*m2)*(m3*m4)
+            m = p(start)
+            n = p(start + 2)
+            k = p(start + 1)
+            allocate(temp(m,n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
+
+            m = p(start + 2)
+            n = p(start + 4)
+            k = p(start + 3)
+            allocate(temp1(m,n))
+            call gemm('N', 'N', m, n, k, one_${s}$, m3, m, m4, k, zero_${s}$, temp1, m)
+
+            ! combine the two partial products
+            m = p(start)
+            n = p(start + 4)
+            k = p(start + 2)
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r, m)
+        else if (ord == start + 2) then
+            ! (m1*m2*m3)*m4
+            temp = matmul_chain_mult_${s}$_3(m1, m2, m3, start, s, p)
+            m = p(start)
+            n = p(start + 4)
+            k = p(start + 3)
+            call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m4, k, zero_${s}$, r, m)
+        else
+            ! 
our internal functions are incorrect, abort
+            error stop this//": error: unexpected s(i,j)"
+        end if
+
+    end function matmul_chain_mult_${s}$_4
+
+    ! Computes res = m1*m2[*m3[*m4[*m5]]].
+    ! The sizes of neighbouring matrices are checked first; on a mismatch a
+    ! LINALG_VALUE_ERROR state is passed to linalg_error_handling together with
+    ! `err`, and res is returned with shape (0, 0). Otherwise the chain
+    ! dimensions are collected in p (matrix i is p(i) x p(i+1)), the split table
+    ! is obtained from matmul_chain_order, and the products are evaluated with gemm.
+    pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err)
+        ${t}$, intent(out), allocatable :: res(:,:)
+        ${t}$, intent(in) :: m1(:,:), m2(:,:)
+        ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
+        type(linalg_state_type), intent(out), optional :: err
+        ${t}$, allocatable :: temp(:,:), temp1(:,:)
+        integer :: p(6), num_present, m, n, k
+        integer, allocatable :: s(:,:)
+
+        type(linalg_state_type) :: err0
+
+        ! m3, m4 and m5 extend the chain in order: a later matrix without its
+        ! predecessor cannot be placed, and the checks below would read an
+        ! absent argument and an undefined entry of p
+        if ((present(m4) .and. .not. present(m3)) .or. &
+            (present(m5) .and. .not. present(m4))) then
+            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, &
+                'optional matrices must be supplied in order: m4 requires m3, m5 requires m4')
+            call linalg_error_handling(err0, err)
+            allocate(res(0, 0))
+            return
+        end if
+
+        p(1) = size(m1, 1)
+        p(2) = size(m2, 1)
+        p(3) = size(m2, 2)
+
+        if (size(m1, 2) /= p(2)) then
+            err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m1=',shape(m1),&
+                                     ', m2=',shape(m2),' have incompatible sizes')
+            call linalg_error_handling(err0, err)
+            allocate(res(0, 0))
+            return
+        end if
+
+        num_present = 2
+        if (present(m3)) then
+
+            if (size(m3, 1) /= p(3)) then
+                err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m2=',shape(m2), &
+                                         ', m3=',shape(m3),' have incompatible sizes')
+                call linalg_error_handling(err0, err)
+                allocate(res(0, 0))
+                return
+            end if
+
+            p(3) = size(m3, 1)
+            p(4) = size(m3, 2)
+            num_present = num_present + 1
+        end if
+        if (present(m4)) then
+
+            if (size(m4, 1) /= p(4)) then
+                err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m3=',shape(m3), &
+                                         ', m4=',shape(m4),' have incompatible sizes')
+                call linalg_error_handling(err0, err)
+                allocate(res(0, 0))
+                return
+            end if
+
+            p(4) = size(m4, 1)
+            p(5) = size(m4, 2)
+            num_present = num_present + 1
+        end if
+        if (present(m5)) then
+
+            if (size(m5, 1) /= p(5)) then
+                err0 = linalg_state_type(this, LINALG_VALUE_ERROR, 'matrices m4=',shape(m4), &
+                                         ', m5=',shape(m5),' have incompatible sizes')
+                call linalg_error_handling(err0, err)
+                allocate(res(0, 0))
+                return
+            end if
+
+            p(5) = size(m5, 1)
+            p(6) = size(m5, 2)
+            num_present = num_present + 1
+        end if
+
+        ! rows of the first matrix x columns of the last matrix supplied
+        allocate(res(p(1), p(num_present + 1)))
+
+        if (num_present == 2) then
+            m = p(1)
+            n = p(3)
+            k = p(2)
+            call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, res, m)
+            return
+        end if
+
+        ! Now num_present >= 3
+        allocate(s(1:num_present - 1, 2:num_present))
+
+        s = matmul_chain_order(p(1: num_present + 1))
+
+        if (num_present == 3) then
+            res = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p(1:4))
+            return
+        else if (num_present == 4) then
+            res = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p(1:5))
+            return
+        end if
+
+        ! Now num_present is 5
+
+        ! s(1, 5) is the position of the outermost split of the full chain
+        select case (s(1, 5))
+            case (1)
+                ! m1*(m2*m3*m4*m5)
+                temp = matmul_chain_mult_${s}$_4(m2, m3, m4, m5, 2, s, p)
+                m = p(1)
+                n = p(6)
+                k = p(2)
+                call gemm('N', 'N', m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, res, m)
+            case (2)
+                ! (m1*m2)*(m3*m4*m5)
+                m = p(1)
+                n = p(3)
+                k = p(2)
+                allocate(temp(m,n))
+                call gemm('N', 'N', m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, temp, m)
+
+                temp1 = matmul_chain_mult_${s}$_3(m3, m4, m5, 3, s, p)
+
+                k = n
+                n = p(6)
+                call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res, m)
+            case (3)
+                ! (m1*m2*m3)*(m4*m5)
+                ! m1 is the first matrix of the chain, so the sub-chain starts at 1
+                temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 1, s, p)
+
+                m = p(4)
+                n = p(6)
+                k = p(5)
+                allocate(temp1(m,n))
+                call gemm('N', 'N', m, n, k, one_${s}$, m4, m, m5, k, zero_${s}$, temp1, m)
+
+                k = m
+                m = p(1)
+                call gemm('N', 'N', m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res, m)
+            case (4)
+                ! 
(m1*m2*m3*m4)*m5
+                temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1, s, p)
+                m = p(1)
+                n = p(6)
+                k = p(5)
+                call gemm('N', 'N', m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, res, m)
+            case default
+                ! a split position outside 1..4 is reported as an internal error
+                err0 = linalg_state_type(this,LINALG_INTERNAL_ERROR,"internal error: unexpected s(i,j)")
+                call linalg_error_handling(err0,err)
+        end select
+
+    end subroutine stdlib_matmul_sub_${s}$
+
+    ! Function form without an error argument: forwards to stdlib_matmul_sub
+    ! and passes the optional matrices straight through.
+    pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
+        ${t}$, intent(in) :: m1(:,:), m2(:,:)
+        ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
+        ${t}$, allocatable :: r(:,:)
+
+        call stdlib_matmul_sub(r, m1, m2, m3, m4, m5)
+    end function stdlib_matmul_pure_${s}$
+
+    ! Function form with an error argument: forwards to stdlib_matmul_sub,
+    ! handing `err` over to receive the error state.
+    module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5, err) result(r)
+        ${t}$, intent(in) :: m1(:,:), m2(:,:)
+        ${t}$, intent(in), optional :: m3(:,:), m4(:,:), m5(:,:)
+        type(linalg_state_type), intent(out) :: err
+        ${t}$, allocatable :: r(:,:)
+
+        call stdlib_matmul_sub(r, m1, m2, m3, m4, m5, err=err)
+    end function stdlib_matmul_${s}$
+
+#:endfor
+end submodule stdlib_intrinsics_matmul
diff --git a/test/intrinsics/test_intrinsics.fypp b/test/intrinsics/test_intrinsics.fypp
index 8aefe09d3..10a17f6ab 100644
--- a/test/intrinsics/test_intrinsics.fypp
+++ b/test/intrinsics/test_intrinsics.fypp
@@ -7,6 +7,7 @@ module test_intrinsics
     use testdrive, only : new_unittest, unittest_type, error_type, check, skip_test
     use stdlib_kinds, only: sp, dp, xdp, qp, int8, int16, int32, int64
     use stdlib_intrinsics
+    use stdlib_linalg_state, only: linalg_state_type, LINALG_VALUE_ERROR, operator(==)
     use stdlib_math, only: swap
     implicit none
 
@@ -19,7 +20,8 @@ subroutine collect_suite(testsuite)
 
     testsuite = [ &
         new_unittest('sum', test_sum), &
-        new_unittest('dot_product', test_dot_product) &
+        new_unittest('dot_product', test_dot_product), &
+        new_unittest('matmul', test_matmul) &
     ]
 end subroutine
 
@@ -249,6 +251,45 @@ subroutine test_dot_product(error)
 
     #:endfor
 end subroutine
+
+subroutine 
test_matmul(error)
+    ! Checks that incompatible sizes are reported through `err`, then compares
+    ! 3- and 4-matrix products against nested calls of the intrinsic matmul.
+    type(error_type), allocatable, intent(out) :: error
+    type(linalg_state_type) :: linerr
+    ! only the shapes of a, b and c matter: b is 3x4 and c is 3x2, so the size
+    ! check fails before any element is read
+    real :: a(2, 3), b(3, 4), c(3, 2)
+    ! allocatable, because the failing call returns a zero-sized matrix and a
+    ! fixed-shape d(2, 2) would not conform to it in the assignment
+    real, allocatable :: d(:,:)
+
+    d = stdlib_matmul(a, b, c, err=linerr)
+    call check(error, linerr == LINALG_VALUE_ERROR, "incompatible matrices are considered compatible")
+    if (allocated(error)) return
+
+    #:for k, t, s in R_KINDS_TYPES
+    block
+        ${t}$ :: x(10,15), y(15,20), z(20,10), r(10,10), r1(10,10)
+        call random_number(x)
+        call random_number(y)
+        call random_number(z)
+
+        r = stdlib_matmul(x, y, z) ! the optimal ordering would be (x(yz))
+        r1 = matmul(matmul(x, y), z) ! the opposite order to induce a difference
+
+        call check(error, all(abs(r-r1) <= epsilon(0._${k}$) * 150), "real, ${k}$, 3 args: error too large")
+        if (allocated(error)) return
+    end block
+
+    block
+        ${t}$ :: x(10,15), y(15,20), z(20,10), w(10, 15), r(10,15), r1(10,15)
+        call random_number(x)
+        call random_number(y)
+        call random_number(z)
+        call random_number(w)
+
+        r = stdlib_matmul(x, y, z, w) ! the optimal order would be ((x(yz))w)
+        r1 = matmul(matmul(x, y), matmul(z, w))
+
+        call check(error, all(abs(r-r1) <= epsilon(0._${k}$) * 800), "real, ${k}$, 4 args: error too large")
+        if (allocated(error)) return
+    end block
+    #:endfor
+end subroutine test_matmul
 
 end module test_intrinsics
 
@@ -276,4 +317,4 @@ program tester
         write(error_unit, '(i0, 1x, a)') stat, "test(s) failed!"
         error stop
     end if
-end program
\ No newline at end of file
+end program
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy