55
66submodule (stdlib_intrinsics) stdlib_intrinsics_matmul
77 use stdlib_linalg_blas, only: gemm
8+ use stdlib_linalg_state, only: linalg_state_type, linalg_error_handling, LINALG_VALUE_ERROR
89 use stdlib_constants
910 implicit none
1011
12+ character (len=* ), parameter :: this = " stdlib_matmul"
13+
1114contains
1215
1316 ! Algorithm for the optimal parenthesization of matrices
@@ -122,41 +125,76 @@ contains
122125
123126 end function matmul_chain_mult_ ${s}$_4
124127
125- pure module function stdlib_matmul_${s}$ (m1, m2, m3, m4, m5) result(r)
128+ pure module subroutine stdlib_matmul_sub_${s}$ (res, m1, m2, m3, m4, m5, err)
129+ ${t}$, intent (out ), allocatable :: res(:,:)
126130 ${t}$, intent (in ) :: m1(:,:), m2(:,:)
127131 ${t}$, intent (in ), optional :: m3(:,:), m4(:,:), m5(:,:)
128- ${t}$, allocatable :: r(:,:), temp(:,:), temp1(:,:)
132+ type(linalg_state_type), intent (out ), optional :: err
133+ ${t}$, allocatable :: temp(:,:), temp1(:,:)
129134 integer :: p(6 ), num_present, m, n, k
130135 integer , allocatable :: s(:,:)
131136
137+ type(linalg_state_type) :: err0
138+
132139 p(1 ) = size (m1, 1 )
133140 p(2 ) = size (m2, 1 )
134141 p(3 ) = size (m2, 2 )
135142
143+ if (size (m1, 2 ) /= p(2 )) then
144+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, ' matrices m1, m2 not of compatible sizes' )
145+ call linalg_error_handling(err0, err)
146+ allocate(res(0 , 0 ))
147+ return
148+ end if
149+
136150 num_present = 2
137151 if (present (m3)) then
152+
153+ if (size (m3, 1 ) /= p(3 )) then
154+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, ' matrices m2, m3 not of compatible sizes' )
155+ call linalg_error_handling(err0, err)
156+ allocate(res(0 , 0 ))
157+ return
158+ end if
159+
138160 p(3 ) = size (m3, 1 )
139161 p(4 ) = size (m3, 2 )
140162 num_present = num_present + 1
141163 end if
142164 if (present (m4)) then
165+
166+ if (size (m4, 1 ) /= p(4 )) then
167+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, ' matrices m3, m4 not of compatible sizes' )
168+ call linalg_error_handling(err0, err)
169+ allocate(res(0 , 0 ))
170+ return
171+ end if
172+
143173 p(4 ) = size (m4, 1 )
144174 p(5 ) = size (m4, 2 )
145175 num_present = num_present + 1
146176 end if
147177 if (present (m5)) then
178+
179+ if (size (m5, 1 ) /= p(5 )) then
180+ err0 = linalg_state_type(this, LINALG_VALUE_ERROR, ' matrices m4, m5 not of compatible sizes' )
181+ call linalg_error_handling(err0, err)
182+ allocate(res(0 , 0 ))
183+ return
184+ end if
185+
148186 p(5 ) = size (m5, 1 )
149187 p(6 ) = size (m5, 2 )
150188 num_present = num_present + 1
151189 end if
152190
153- allocate(r (p(1 ), p(num_present + 1 )))
191+ allocate(res (p(1 ), p(num_present + 1 )))
154192
155193 if (num_present == 2 ) then
156194 m = p(1 )
157195 n = p(3 )
158196 k = p(2 )
159- call gemm(' N' , ' N' , m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, r , m)
197+ call gemm(' N' , ' N' , m, n, k, one_${s}$, m1, m, m2, k, zero_${s}$, res , m)
160198 return
161199 end if
162200
@@ -166,10 +204,10 @@ contains
166204 s = matmul_chain_order(p(1 : num_present + 1 ))
167205
168206 if (num_present == 3 ) then
169- r = matmul_chain_mult_${s}$_3(m1, m2, m3, 1 , s, p(1 :4 ))
207+ res = matmul_chain_mult_${s}$_3(m1, m2, m3, 1 , s, p(1 :4 ))
170208 return
171209 else if (num_present == 4 ) then
172- r = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1 , s, p(1 :5 ))
210+ res = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1 , s, p(1 :5 ))
173211 return
174212 end if
175213
@@ -182,7 +220,7 @@ contains
182220 m = p(1 )
183221 n = p(6 )
184222 k = p(2 )
185- call gemm(' N' , ' N' , m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, r , m)
223+ call gemm(' N' , ' N' , m, n, k, one_${s}$, m1, m, temp, k, zero_${s}$, res , m)
186224 case (2 )
187225 ! (m1* m2)* (m3* m4* m5)
188226 m = p(1 )
@@ -195,7 +233,7 @@ contains
195233
196234 k = n
197235 n = p(6 )
198- call gemm(' N' , ' N' , m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r , m)
236+ call gemm(' N' , ' N' , m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res , m)
199237 case (3 )
200238 ! (m1* m2* m3)* (m4* m5)
201239 temp = matmul_chain_mult_${s}$_3(m1, m2, m3, 3 , s, p)
@@ -208,18 +246,35 @@ contains
208246
209247 k = m
210248 m = p(1 )
211- call gemm(' N' , ' N' , m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, r , m)
249+ call gemm(' N' , ' N' , m, n, k, one_${s}$, temp, m, temp1, k, zero_${s}$, res , m)
212250 case (4 )
213251 ! (m1* m2* m3* m4)* m5
214252 temp = matmul_chain_mult_${s}$_4(m1, m2, m3, m4, 1 , s, p)
215253 m = p(1 )
216254 n = p(6 )
217255 k = p(5 )
218- call gemm(' N' , ' N' , m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, r , m)
256+ call gemm(' N' , ' N' , m, n, k, one_${s}$, temp, m, m5, k, zero_${s}$, res , m)
219257 case default
220- error stop " stdlib_matmul: error: unexpected s(i,j)"
258+ error stop " stdlib_matmul: internal error: unexpected s(i,j)"
221259 end select
222260
261+ end subroutine stdlib_matmul_sub_${s}$
262+
263+ pure module function stdlib_matmul_pure_${s}$ (m1, m2, m3, m4, m5) result(r)
264+ ${t}$, intent (in ) :: m1(:,:), m2(:,:)
265+ ${t}$, intent (in ), optional :: m3(:,:), m4(:,:), m5(:,:)
266+ ${t}$, allocatable :: r(:,:)
267+
268+ call stdlib_matmul_sub(r, m1, m2, m3, m4, m5)
269+ end function stdlib_matmul_pure_${s}$
270+
271+ module function stdlib_matmul_ ${s}$ (m1, m2, m3, m4, m5, err) result(r)
272+ ${t}$, intent (in ) :: m1(:,:), m2(:,:)
273+ ${t}$, intent (in ), optional :: m3(:,:), m4(:,:), m5(:,:)
274+ type(linalg_state_type), intent (out ) :: err
275+ ${t}$, allocatable :: r(:,:)
276+
277+ call stdlib_matmul_sub(r, m1, m2, m3, m4, m5, err= err)
223278 end function stdlib_matmul_ ${s}$
224279
225280#:endfor
0 commit comments