Skip to content

Commit 899c6c0

Browse files
committed
Fix bug in tensor matmul for matrix-vector case. einsum for matrix-vector also dispatches to matmul now which is much faster
1 parent 8f4c6ae commit 899c6c0

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

tensor/TensorFunctions.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ FASTOR_INLINE Tensor<T,I,K> matmul(const Tensor<T,I,J> &a, const Tensor<T,J,K> &
7171
}
7272

7373
template<typename T, size_t I, size_t J>
74-
FASTOR_INLINE Tensor<T,J> matmul(const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
74+
FASTOR_INLINE Tensor<T,I> matmul(const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
7575
// Hack clang to get around alignment
7676
#if defined(__llvm__) || defined(__clang__)
7777
unused(a);
7878
#endif
79-
Tensor<T,J> out;
79+
Tensor<T,I> out;
8080
_matmul<T,I,J,1>(a.data(),b.data(),out.data());
8181
return out;
8282
}

tensor_algebra/einsum.h

+31-2
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,9 @@ einsum(const Tensor<T,I,J> & a, const Tensor<T,K,L> &b) {
444444
}
445445

446446

447-
// matmul dispatcher for 2nd order tensors
447+
// matmul dispatcher for 2nd order tensors (matrix-matrix)
448+
// also includes matrix-vector and vector-matrix when vector is of size
449+
// nx1 or 1xn
448450
template<class Ind0, class Ind1,
449451
typename T, size_t I, size_t J, size_t K,
450452
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
@@ -454,13 +456,40 @@ template<class Ind0, class Ind1,
454456
Ind0::_IndexHolder[0] != Ind1::_IndexHolder[1],bool>::type = 0>
455457
FASTOR_INLINE Tensor<T,I,K>
456458
einsum(const Tensor<T,I,J> &a, const Tensor<T,J,K> &b) {
457-
458459
Tensor<T,I,K> out;
459460
_matmul<T,I,J,K>(a.data(),b.data(),out.data());
460461
return out;
461462
}
462463

463464

465+
// matmul dispatcher for matrix-vector
466+
template<class Ind0, class Ind1,
467+
typename T, size_t I, size_t J,
468+
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
469+
Ind0::_IndexHolder[1] == Ind1::_IndexHolder[0] &&
470+
Ind0::_IndexHolder[1] != Ind0::_IndexHolder[0],bool>::type = 0>
471+
FASTOR_INLINE Tensor<T,I>
472+
einsum(const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
473+
Tensor<T,I> out;
474+
_matmul<T,I,J,1>(a.data(),b.data(),out.data());
475+
return out;
476+
}
477+
478+
479+
// matmul dispatcher for vector-matrix
480+
template<class Ind0, class Ind1,
481+
typename T, size_t I, size_t J,
482+
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
483+
Ind0::_IndexHolder[1] == Ind1::_IndexHolder[0] &&
484+
Ind0::_IndexHolder[1] != Ind0::_IndexHolder[0],bool>::type = 0>
485+
FASTOR_INLINE Tensor<T,J>
486+
einsum(const Tensor<T,I> &a, const Tensor<T,I,J> &b) {
487+
Tensor<T,J> out;
488+
_matmul<T,I,J,1>(a.data(),b.data(),out.data());
489+
return out;
490+
}
491+
492+
464493
// The following two overloads are provided for an external use case
465494
// A_ijk*B_kl
466495
template<class Ind0, class Ind1,

0 commit comments

Comments
 (0)