@@ -376,6 +376,82 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
376
376
#endif
377
377
378
378
379
+
380
+ // matmul dispatcher for 2nd order tensors (matrix-matrix)
381
+ // also includes matrix-vector and vector-matrix when vector is of size
382
+ // nx1 or 1xn
383
+ template <class Ind0 , class Ind1 ,
384
+ typename T, size_t I, size_t J, size_t K,
385
+ typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
386
+ Ind0::_IndexHolder[1 ] == Ind1::_IndexHolder[0 ] &&
387
+ Ind0::_IndexHolder[1 ] != Ind0::_IndexHolder[0 ] &&
388
+ Ind0::_IndexHolder[1 ] != Ind1::_IndexHolder[1 ] &&
389
+ Ind0::_IndexHolder[0 ] != Ind1::_IndexHolder[1 ],bool >::type = 0 >
390
+ FASTOR_INLINE Tensor<T,I,K>
391
+ einsum (const Tensor<T,I,J> &a, const Tensor<T,J,K> &b) {
392
+ Tensor<T,I,K> out;
393
+ _matmul<T,I,J,K>(a.data (),b.data (),out.data ());
394
+ return out;
395
+ }
396
+
397
+
398
+ // matmul dispatcher for matrix-vector
399
+ template <class Ind0 , class Ind1 ,
400
+ typename T, size_t I, size_t J,
401
+ typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
402
+ Ind0::_IndexHolder[1 ] == Ind1::_IndexHolder[0 ] &&
403
+ Ind0::_IndexHolder[0 ] != Ind1::_IndexHolder[0 ]
404
+ ,bool >::type = 0 >
405
+ FASTOR_INLINE Tensor<T,I>
406
+ einsum (const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
407
+ Tensor<T,I> out;
408
+ _matmul<T,I,J,1 >(a.data (),b.data (),out.data ());
409
+ return out;
410
+ }
411
+
412
+ // matmul dispatcher for matrix-vector
413
+ template <class Ind0 , class Ind1 ,
414
+ typename T, size_t I, size_t J,
415
+ typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
416
+ Ind0::_IndexHolder[0 ] == Ind1::_IndexHolder[0 ] &&
417
+ Ind0::_IndexHolder[1 ] != Ind1::_IndexHolder[0 ],bool >::type = 0 >
418
+ FASTOR_INLINE Tensor<T,J>
419
+ einsum (const Tensor<T,I,J> &a, const Tensor<T,I> &b) {
420
+ Tensor<T,J> out;
421
+ _matmul<T,1 ,I,J>(b.data (),a.data (),out.data ());
422
+ return out;
423
+ }
424
+
425
+
426
+ // matmul dispatcher for vector-matrix
427
+ template <class Ind0 , class Ind1 ,
428
+ typename T, size_t I, size_t J,
429
+ typename std::enable_if<Ind1::NoIndices==2 && Ind0::NoIndices==1 &&
430
+ Ind1::_IndexHolder[0 ] == Ind0::_IndexHolder[0 ] &&
431
+ Ind1::_IndexHolder[1 ] != Ind0::_IndexHolder[0 ],bool >::type = 0 >
432
+ FASTOR_INLINE Tensor<T,J>
433
+ einsum (const Tensor<T,I> &a, const Tensor<T,I,J> &b) {
434
+ Tensor<T,J> out;
435
+ _matmul<T,1 ,I,J>(a.data (),b.data (),out.data ());
436
+ return out;
437
+ }
438
+
439
+
440
+ // matmul dispatcher for vector-matrix
441
+ template <class Ind0 , class Ind1 ,
442
+ typename T, size_t I, size_t J,
443
+ typename std::enable_if<Ind1::NoIndices==2 && Ind0::NoIndices==1 &&
444
+ Ind1::_IndexHolder[1 ] == Ind0::_IndexHolder[0 ] &&
445
+ Ind1::_IndexHolder[0 ] != Ind0::_IndexHolder[0 ],bool >::type = 0 >
446
+ FASTOR_INLINE Tensor<T,I>
447
+ einsum (const Tensor<T,J> &a, const Tensor<T,I,J> &b) {
448
+ Tensor<T,I> out;
449
+ _matmul<T,I,J,1 >(b.data (),a.data (),out.data ());
450
+ return out;
451
+ }
452
+
453
+
454
+
379
455
#ifdef __AVX__
380
456
381
457
// Specific overloads
@@ -444,52 +520,6 @@ einsum(const Tensor<T,I,J> & a, const Tensor<T,K,L> &b) {
444
520
}
445
521
446
522
447
- // matmul dispatcher for 2nd order tensors (matrix-matrix)
448
- // also includes matrix-vector and vector-matrix when vector is of size
449
- // nx1 or 1xn
450
- template <class Ind0 , class Ind1 ,
451
- typename T, size_t I, size_t J, size_t K,
452
- typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==2 &&
453
- Ind0::_IndexHolder[1 ] == Ind1::_IndexHolder[0 ] &&
454
- Ind0::_IndexHolder[1 ] != Ind0::_IndexHolder[0 ] &&
455
- Ind0::_IndexHolder[1 ] != Ind1::_IndexHolder[1 ] &&
456
- Ind0::_IndexHolder[0 ] != Ind1::_IndexHolder[1 ],bool >::type = 0 >
457
- FASTOR_INLINE Tensor<T,I,K>
458
- einsum (const Tensor<T,I,J> &a, const Tensor<T,J,K> &b) {
459
- Tensor<T,I,K> out;
460
- _matmul<T,I,J,K>(a.data (),b.data (),out.data ());
461
- return out;
462
- }
463
-
464
-
465
- // matmul dispatcher for matrix-vector
466
- template <class Ind0 , class Ind1 ,
467
- typename T, size_t I, size_t J,
468
- typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
469
- Ind0::_IndexHolder[1 ] == Ind1::_IndexHolder[0 ] &&
470
- Ind0::_IndexHolder[1 ] != Ind0::_IndexHolder[0 ],bool >::type = 0 >
471
- FASTOR_INLINE Tensor<T,I>
472
- einsum (const Tensor<T,I,J> &a, const Tensor<T,J> &b) {
473
- Tensor<T,I> out;
474
- _matmul<T,I,J,1 >(a.data (),b.data (),out.data ());
475
- return out;
476
- }
477
-
478
-
479
- // matmul dispatcher for vector-matrix
480
- template <class Ind0 , class Ind1 ,
481
- typename T, size_t I, size_t J,
482
- typename std::enable_if<Ind0::NoIndices==2 && Ind1::NoIndices==1 &&
483
- Ind0::_IndexHolder[1 ] == Ind1::_IndexHolder[0 ] &&
484
- Ind0::_IndexHolder[1 ] != Ind0::_IndexHolder[0 ],bool >::type = 0 >
485
- FASTOR_INLINE Tensor<T,J>
486
- einsum (const Tensor<T,I> &a, const Tensor<T,I,J> &b) {
487
- Tensor<T,J> out;
488
- _matmul<T,I,J,1 >(a.data (),b.data (),out.data ());
489
- return out;
490
- }
491
-
492
-
493
523
// The following two overloads are provided for an external use case
494
524
// A_ijk*B_kl
495
525
template <class Ind0 , class Ind1 ,
0 commit comments