Skip to content

Commit 8241ac8

Browse files
committed
Fix a bug in matrix-vector einsum
1 parent ec82d56 commit 8241ac8

File tree

5 files changed

+111
-46
lines changed

5 files changed

+111
-46
lines changed

dd/Makefile

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# Build the stand-alone einsum driver from main.cpp.
# -I../ adds the parent directory to the include path (presumably where
# Fastor.h lives -- confirm), -O3 enables optimisation and -mavx enables AVX
# code generation.
all:
2+
$(CXX) main.cpp -o main -I../ -O3 -mavx

dd/main

13.6 KB
Binary file not shown.

dd/main.cpp

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include <Fastor.h>
2+
3+
using namespace Fastor;
4+
5+
6+
// Small driver that exercises the matrix-vector / vector-matrix einsum
// overloads on a 2x2 matrix and a length-2 vector and prints the operands
// together with the result.
int main()
7+
{
8+
9+
// Enumerators used as index labels in the Index<...> template arguments below.
enum{i,j,k};
10+
Tensor<double, 2, 2> a = {{1, 2}, {3, 4}};
11+
Tensor<double, 2> w = {1, 1};
12+
// Alternative contractions, kept commented out for manual experimentation:
// Tensor<double,2> e3 = einsum<Index<i, j>, Index <i> >(a, w);
13+
// Tensor<double,2> e3 = einsum<Index<i, j>, Index <j> >(a, w);
14+
15+
// Tensor<double,2> e3 = einsum<Index<i>, Index <i,j> >(w, a);
16+
// Active case: the vector's index j matches the second index of the matrix.
Tensor<double,2> e3 = einsum<Index<j>, Index <i,j> >(w, a);
17+
18+
19+
print(a,w,e3);
20+
21+
return 0;
22+
}

tensor_algebra/einsum.h

+76-46
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,82 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
376376
#endif
377377

378378

379+
380+
// matmul dispatcher for 2nd order tensors (matrix-matrix): A_ij * B_jk -> C_ik
381+
// Also covers matrix-vector and vector-matrix products when the vector is
382+
// stored as a 2nd order tensor of size nx1 or 1xn.
// Enabled only when both index lists have two entries, the second index of
// Ind0 equals the first index of Ind1, and that shared index differs from the
// two remaining indices, which must also differ from each other.
// The operands are forwarded unchanged to _matmul<T,I,J,K>.
383+
template<class Ind0, class Ind1,
384+
typename T, size_t I, size_t J, size_t K,
385+
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
386+
Ind0::_IndexHolder[1] == Ind1::_IndexHolder[0] &&
387+
Ind0::_IndexHolder[1] != Ind0::_IndexHolder[0] &&
388+
Ind0::_IndexHolder[1] != Ind1::_IndexHolder[1] &&
389+
Ind0::_IndexHolder[0] != Ind1::_IndexHolder[1],bool>::type = 0>
390+
FASTOR_INLINE Tensor<T,I,K>
391+
einsum(const Tensor<T,I,J> &a, const Tensor<T,J,K> &b) {
392+
Tensor<T,I,K> out;
393+
_matmul<T,I,J,K>(a.data(),b.data(),out.data());
394+
return out;
395+
}
396+
397+
398+
// matmul dispatcher for matrix-vector: A_ij * b_j -> c_i
// Enabled when the vector's index equals the SECOND index of the matrix and
// differs from the first one. The vector is handed to _matmul<T,I,J,1>, i.e.
// it is treated as a Jx1 matrix (assuming _matmul<T,M,K,N> multiplies an MxK
// by a KxN operand -- confirm against its definition).
399+
template<class Ind0, class Ind1,
400+
typename T, size_t I, size_t J,
401+
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
402+
Ind0::_IndexHolder[1] == Ind1::_IndexHolder[0] &&
403+
Ind0::_IndexHolder[0] != Ind1::_IndexHolder[0]
404+
,bool>::type = 0>
405+
FASTOR_INLINE Tensor<T,I>
406+
einsum(const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
407+
Tensor<T,I> out;
408+
_matmul<T,I,J,1>(a.data(),b.data(),out.data());
409+
return out;
410+
}
411+
412+
// matmul dispatcher for matrix-vector contracted over the matrix's FIRST
// index: A_ij * b_i -> c_j
// Enabled when the vector's index equals the first index of the matrix and
// differs from the second one. The operands are passed to _matmul<T,1,I,J>
// in swapped order (vector first), i.e. the product is evaluated as a 1xI
// row vector times the IxJ matrix (assuming _matmul<T,M,K,N> multiplies an
// MxK by a KxN operand -- confirm against its definition).
413+
template<class Ind0, class Ind1,
414+
typename T, size_t I, size_t J,
415+
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
416+
Ind0::_IndexHolder[0] == Ind1::_IndexHolder[0] &&
417+
Ind0::_IndexHolder[1] != Ind1::_IndexHolder[0],bool>::type = 0>
418+
FASTOR_INLINE Tensor<T,J>
419+
einsum(const Tensor<T,I,J> &a, const Tensor<T,I> &b) {
420+
Tensor<T,J> out;
421+
_matmul<T,1,I,J>(b.data(),a.data(),out.data());
422+
return out;
423+
}
424+
425+
426+
// matmul dispatcher for vector-matrix: a_i * B_ij -> c_j
// Enabled when the vector's index equals the FIRST index of the matrix and
// differs from the second one. The vector is handed to _matmul<T,1,I,J>,
// i.e. it is treated as a 1xI row vector (assuming _matmul<T,M,K,N>
// multiplies an MxK by a KxN operand -- confirm against its definition).
427+
template<class Ind0, class Ind1,
428+
typename T, size_t I, size_t J,
429+
typename std::enable_if<Ind1::NoIndices==2 && Ind0::NoIndices==1 &&
430+
Ind1::_IndexHolder[0] == Ind0::_IndexHolder[0] &&
431+
Ind1::_IndexHolder[1] != Ind0::_IndexHolder[0],bool>::type = 0>
432+
FASTOR_INLINE Tensor<T,J>
433+
einsum(const Tensor<T,I> &a, const Tensor<T,I,J> &b) {
434+
Tensor<T,J> out;
435+
_matmul<T,1,I,J>(a.data(),b.data(),out.data());
436+
return out;
437+
}
438+
439+
440+
// matmul dispatcher for vector-matrix contracted over the matrix's SECOND
// index: a_j * B_ij -> c_i
// Enabled when the vector's index equals the second index of the matrix and
// differs from the first one. The operands are passed to _matmul<T,I,J,1>
// in swapped order (matrix first), i.e. the product is evaluated as the IxJ
// matrix times a Jx1 column vector (assuming _matmul<T,M,K,N> multiplies an
// MxK by a KxN operand -- confirm against its definition).
441+
template<class Ind0, class Ind1,
442+
typename T, size_t I, size_t J,
443+
typename std::enable_if<Ind1::NoIndices==2 && Ind0::NoIndices==1 &&
444+
Ind1::_IndexHolder[1] == Ind0::_IndexHolder[0] &&
445+
Ind1::_IndexHolder[0] != Ind0::_IndexHolder[0],bool>::type = 0>
446+
FASTOR_INLINE Tensor<T,I>
447+
einsum(const Tensor<T,J> &a, const Tensor<T,I,J> &b) {
448+
Tensor<T,I> out;
449+
_matmul<T,I,J,1>(b.data(),a.data(),out.data());
450+
return out;
451+
}
452+
453+
454+
379455
#ifdef __AVX__
380456

381457
// Specific overloads
@@ -444,52 +520,6 @@ einsum(const Tensor<T,I,J> & a, const Tensor<T,K,L> &b) {
444520
}
445521

446522

447-
// matmul dispatcher for 2nd order tensors (matrix-matrix)
448-
// also includes matrix-vector and vector-matrix when vector is of size
449-
// nx1 or 1xn
450-
template<class Ind0, class Ind1,
451-
typename T, size_t I, size_t J, size_t K,
452-
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
453-
Ind0::_IndexHolder[1] == Ind1::_IndexHolder[0] &&
454-
Ind0::_IndexHolder[1] != Ind0::_IndexHolder[0] &&
455-
Ind0::_IndexHolder[1] != Ind1::_IndexHolder[1] &&
456-
Ind0::_IndexHolder[0] != Ind1::_IndexHolder[1],bool>::type = 0>
457-
FASTOR_INLINE Tensor<T,I,K>
458-
einsum(const Tensor<T,I,J> &a, const Tensor<T,J,K> &b) {
459-
Tensor<T,I,K> out;
460-
_matmul<T,I,J,K>(a.data(),b.data(),out.data());
461-
return out;
462-
}
463-
464-
465-
// matmul dispatcher for matrix-vector
466-
template<class Ind0, class Ind1,
467-
typename T, size_t I, size_t J,
468-
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
469-
Ind0::_IndexHolder[1] == Ind1::_IndexHolder[0] &&
470-
Ind0::_IndexHolder[1] != Ind0::_IndexHolder[0],bool>::type = 0>
471-
FASTOR_INLINE Tensor<T,I>
472-
einsum(const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
473-
Tensor<T,I> out;
474-
_matmul<T,I,J,1>(a.data(),b.data(),out.data());
475-
return out;
476-
}
477-
478-
479-
// matmul dispatcher for vector-matrix
480-
template<class Ind0, class Ind1,
481-
typename T, size_t I, size_t J,
482-
typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
483-
Ind0::_IndexHolder[1] == Ind1::_IndexHolder[0] &&
484-
Ind0::_IndexHolder[1] != Ind0::_IndexHolder[0],bool>::type = 0>
485-
FASTOR_INLINE Tensor<T,J>
486-
einsum(const Tensor<T,I> &a, const Tensor<T,I,J> &b) {
487-
Tensor<T,J> out;
488-
_matmul<T,I,J,1>(a.data(),b.data(),out.data());
489-
return out;
490-
}
491-
492-
493523
// The following two overloads are provided for an external use case
494524
// A_ijk*B_kl
495525
template<class Ind0, class Ind1,

tests/test_einsum.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,17 @@ void run() {
230230
assert(abs(As.sum() - Bs4.sum()) < BigTol);
231231
}
232232

233+
{
// Regression checks for the four matrix-vector / vector-matrix einsum
// overloads on a non-square (3x2) matrix, so that contracting over the
// wrong index cannot go unnoticed.
234+
Tensor<T,3,2> As; As.iota(1);
235+
Tensor<T,3> bs; bs.fill(1);
236+
Tensor<T,2> cs; cs.fill(2);
237+
238+
// Expected sums assume iota(1) fills As with 1..6: contracting with the
// all-twos vector gives 2*21 = 42, with the all-ones vector 21.
// The difference is wrapped in abs() so that a result which is too SMALL
// fails as well; a bare `x - 42. < Tol` holds for every x below 42.
assert(abs((einsum<Index<i,j>,Index<j>>(As,cs)).sum() - 42.) < Tol);
239+
assert(abs((einsum<Index<i,j>,Index<i>>(As,bs)).sum() - 21.) < Tol);
240+
assert(abs((einsum<Index<i>,Index<i,j>>(bs,As)).sum() - 21.) < Tol);
241+
assert(abs((einsum<Index<j>,Index<i,j>>(cs,As)).sum() - 42.) < Tol);
242+
}
243+
233244
print(FGRN(BOLD("All tests passed successfully")));
234245
}
235246

0 commit comments

Comments
 (0)