@@ -444,7 +444,9 @@ einsum(const Tensor<T,I,J> & a, const Tensor<T,K,L> &b) {
444
444
}
445
445
446
446
447
- // matmul dispatcher for 2nd order tensors
447
+ // matmul dispatcher for 2nd order tensors (matrix-matrix)
448
+ // also includes matrix-vector and vector-matrix when vector is of size
449
+ // nx1 or 1xn
448
450
template <class Ind0 , class Ind1 ,
449
451
typename T, size_t I, size_t J, size_t K,
450
452
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
@@ -454,13 +456,40 @@ template<class Ind0, class Ind1,
454
456
Ind0::_IndexHolder[0 ] != Ind1::_IndexHolder[1 ],bool >::type = 0 >
455
457
FASTOR_INLINE Tensor<T,I,K>
456
458
einsum (const Tensor<T,I,J> &a, const Tensor<T,J,K> &b) {
457
-
458
459
Tensor<T,I,K> out;
459
460
_matmul<T,I,J,K>(a.data (),b.data (),out.data ());
460
461
return out;
461
462
}
462
463
463
464
465
+ // matmul dispatcher for matrix-vector
466
+ template <class Ind0 , class Ind1 ,
467
+ typename T, size_t I, size_t J,
468
+ typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
469
+ Ind0::_IndexHolder[1 ] == Ind1::_IndexHolder[0 ] &&
470
+ Ind0::_IndexHolder[1 ] != Ind0::_IndexHolder[0 ],bool >::type = 0 >
471
+ FASTOR_INLINE Tensor<T,I>
472
+ einsum (const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
473
+ Tensor<T,I> out;
474
+ _matmul<T,I,J,1 >(a.data (),b.data (),out.data ());
475
+ return out;
476
+ }
477
+
478
+
479
+ // matmul dispatcher for vector-matrix
480
+ template <class Ind0 , class Ind1 ,
481
+ typename T, size_t I, size_t J,
482
+ typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
483
+ Ind0::_IndexHolder[1 ] == Ind1::_IndexHolder[0 ] &&
484
+ Ind0::_IndexHolder[1 ] != Ind0::_IndexHolder[0 ],bool >::type = 0 >
485
+ FASTOR_INLINE Tensor<T,J>
486
+ einsum (const Tensor<T,I> &a, const Tensor<T,I,J> &b) {
487
+ Tensor<T,J> out;
488
+ _matmul<T,I,J,1 >(a.data (),b.data (),out.data ());
489
+ return out;
490
+ }
491
+
492
+
464
493
// The following two overloads are provided for an external use case
465
494
// A_ijk*B_kl
466
495
template <class Ind0 , class Ind1 ,
0 commit comments