Skip to content

Commit 4ff2ea0

Browse files
committed
Fix a serious bug in strided_contraction for cases where the second/last tensor disappears
1 parent 70838d2 commit 4ff2ea0

File tree

5 files changed

+42
-9
lines changed

5 files changed

+42
-9
lines changed

meta/einsum_meta.h

+6
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ template<typename T, size_t ...Idx0, size_t ...Idx1, size_t...Rest>
179179
struct is_vectorisable<Index<Idx0...>,Index<Idx1...>,Tensor<T,Rest...>> {
180180
static constexpr size_t fastest_changing_index = get_value<sizeof...(Rest),Rest...>::value;
181181
static constexpr size_t idx[sizeof...(Idx0)] = {Idx0...};
182+
static constexpr bool does_2nd_tensor_disappear = ((int)no_of_unique<Idx0...,Idx1...>::value == (int)sizeof...(Idx0) - (int)sizeof...(Idx1));
182183
static constexpr bool last_index_contracted = contains(idx,get_value<sizeof...(Idx1),Idx1...>::value);
184+
static constexpr bool is_reducible = does_2nd_tensor_disappear && last_index_contracted;
183185
static constexpr bool value = (!last_index_contracted) && (fastest_changing_index % get_vector_size<T,FASTOR_SSE>::size==0);
184186
static constexpr bool sse_vectorisability = (!last_index_contracted) &&
185187
(fastest_changing_index % get_vector_size<T,FASTOR_SSE>::size==0 && fastest_changing_index % get_vector_size<T,FASTOR_AVX>::size!=0);
@@ -196,7 +198,9 @@ template<size_t ...Idx0, size_t ...Idx1, size_t...Rest>
196198
struct is_vectorisable<Index<Idx0...>,Index<Idx1...>,Tensor<float,Rest...>> {
197199
static constexpr size_t fastest_changing_index = get_value<sizeof...(Rest),Rest...>::value;
198200
static constexpr size_t idx[sizeof...(Idx0)] = {Idx0...};
201+
static constexpr bool does_2nd_tensor_disappear = ((int)no_of_unique<Idx0...,Idx1...>::value == (int)sizeof...(Idx0) - (int)sizeof...(Idx1));
199202
static constexpr bool last_index_contracted = contains(idx,get_value<sizeof...(Idx1),Idx1...>::value);
203+
static constexpr bool is_reducible = does_2nd_tensor_disappear && last_index_contracted;
200204
static constexpr bool value = (!last_index_contracted) && (fastest_changing_index % 4==0);
201205
static constexpr bool sse_vectorisability = (!last_index_contracted) && (fastest_changing_index % 4==0 && fastest_changing_index % 8!=0);
202206
static constexpr bool avx_vectorisability = (!last_index_contracted) && (fastest_changing_index % 4==0 && fastest_changing_index % 8==0);
@@ -210,7 +214,9 @@ template<size_t ...Idx0, size_t ...Idx1, size_t...Rest>
210214
struct is_vectorisable<Index<Idx0...>,Index<Idx1...>,Tensor<double,Rest...>> {
211215
static constexpr size_t fastest_changing_index = get_value<sizeof...(Rest),Rest...>::value;
212216
static constexpr size_t idx[sizeof...(Idx0)] = {Idx0...};
217+
static constexpr bool does_2nd_tensor_disappear = ((int)no_of_unique<Idx0...,Idx1...>::value == (int)sizeof...(Idx0) - (int)sizeof...(Idx1));
213218
static constexpr bool last_index_contracted = contains(idx,get_value<sizeof...(Idx1),Idx1...>::value);
219+
static constexpr bool is_reducible = does_2nd_tensor_disappear && last_index_contracted;
214220
static constexpr bool value = (!last_index_contracted) && (fastest_changing_index % 2==0);
215221
static constexpr bool sse_vectorisability = (!last_index_contracted) && (fastest_changing_index % 2==0 && fastest_changing_index % 4!=0);
216222
static constexpr bool avx_vectorisability = (!last_index_contracted) && (fastest_changing_index % 2==0 && fastest_changing_index % 4==0);

simd_vector/simd_vector_base.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ struct SIMDVector {
130130
}
131131
template<typename U, typename ... Args>
132132
FASTOR_INLINE void set(U first, Args ... args) {
133-
unused(first);
133+
T arr[Size] = {first,args...};
134+
std::reverse_copy(arr, arr+Size, value);
134135
// Relax this restriction
135136
// static_assert(sizeof...(args)==1,"CANNOT SET VECTOR WITH VALUES DUE TO ABI CONSIDERATION");
136137
}

tensor_algebra/einsum.h

+8-7
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b)
6565

6666
// Dispatch to the right routine
6767
using vectorisability = is_vectorisable<Index_I,Index_J,Tensor<T,Rest1...>>;
68-
constexpr bool is_reducible = vectorisability::last_index_contracted;
68+
// constexpr bool is_reducible = vectorisability::last_index_contracted;
69+
constexpr bool is_reducible = vectorisability::is_reducible;
6970
if (is_reducible) {
7071
return extractor_reducible_contract<Index_I,Index_J>::contract_impl(a,b);
7172
}
@@ -94,7 +95,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b, const Tens
9495
// Dispatch to the right routine
9596
using Index0 = typename concat_<Index_I,Index_J>::type;
9697
using vectorisability = is_vectorisable<Index0,Index_K,Tensor<T,Rest2...>>;
97-
constexpr bool is_reducible = vectorisability::last_index_contracted;
98+
constexpr bool is_reducible = vectorisability::is_reducible;
9899
if (is_reducible) {
99100
return extractor_strided_contract<Index_I,Index_J,Index_K>::contract_impl(a,b,c);
100101
}
@@ -117,7 +118,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b, const Tens
117118
// Dispatch to the right routine
118119
using Index0 = typename concat_<Index_I,Index_J,Index_K>::type;
119120
using vectorisability = is_vectorisable<Index0,Index_L,Tensor<T,Rest3...>>;
120-
constexpr bool is_reducible = vectorisability::last_index_contracted;
121+
constexpr bool is_reducible = vectorisability::is_reducible;
121122
if (is_reducible) {
122123
return extractor_strided_contract_4<Index_I,Index_J,Index_K,Index_L>::contract_impl(a,b,c,d);
123124
}
@@ -142,7 +143,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
142143
// Dispatch to the right routine
143144
using Index0 = typename concat_<Index_I,Index_J,Index_K,Index_L>::type;
144145
using vectorisability = is_vectorisable<Index0,Index_M,Tensor<T,Rest4...>>;
145-
constexpr bool is_reducible = vectorisability::last_index_contracted;
146+
constexpr bool is_reducible = vectorisability::is_reducible;
146147
if (is_reducible) {
147148
return extractor_strided_contract_5<Index_I,Index_J,Index_K,Index_L,Index_M>::contract_impl(a,b,c,d,e);
148149
}
@@ -166,7 +167,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
166167
// Dispatch to the right routine
167168
using Index0 = typename concat_<Index_I,Index_J,Index_K,Index_L,Index_M>::type;
168169
using vectorisability = is_vectorisable<Index0,Index_N,Tensor<T,Rest5...>>;
169-
constexpr bool is_reducible = vectorisability::last_index_contracted;
170+
constexpr bool is_reducible = vectorisability::is_reducible;
170171
if (is_reducible) {
171172
return extractor_strided_contract_6<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N>::contract_impl(a,b,c,d,e,f);
172173
}
@@ -192,7 +193,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
192193
// Dispatch to the right routine
193194
using Index0 = typename concat_<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N>::type;
194195
using vectorisability = is_vectorisable<Index0,Index_O,Tensor<T,Rest6...>>;
195-
constexpr bool is_reducible = vectorisability::last_index_contracted;
196+
constexpr bool is_reducible = vectorisability::is_reducible;
196197
if (is_reducible) {
197198
return extractor_strided_contract_7<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N,Index_O>::contract_impl(a,b,c,d,e,f,g);
198199
}
@@ -219,7 +220,7 @@ auto einsum(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b,
219220
// Dispatch to the right routine
220221
using Index0 = typename concat_<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N,Index_O>::type;
221222
using vectorisability = is_vectorisable<Index0,Index_P,Tensor<T,Rest7...>>;
222-
constexpr bool is_reducible = vectorisability::last_index_contracted;
223+
constexpr bool is_reducible = vectorisability::is_reducible;
223224
if (is_reducible) {
224225
return extractor_strided_contract_8<Index_I,Index_J,Index_K,Index_L,Index_M,Index_N,Index_O,Index_P>::contract_impl(a,b,c,d,e,f,g,h);
225226
}

tensor_algebra/strided_contraction.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ struct extractor_reducible_contract<Index<Idx0...>, Index<Idx1...>> {
282282

283283
template<class Index_I, class Index_J,
284284
typename T, size_t ... Rest0, size_t ... Rest1>
285-
auto strided_contraction(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b)
285+
FASTOR_INLINE auto strided_contraction(const Tensor<T,Rest0...> &a, const Tensor<T,Rest1...> &b)
286286
-> decltype(extractor_reducible_contract<Index_I,Index_J>::contract_impl(a,b)) {
287287
return extractor_reducible_contract<Index_I,Index_J>::contract_impl(a,b);
288288
}

tests/test_einsum.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,31 @@ void run() {
241241
assert((einsum<Index<j>,Index<i,j>>(cs,As)).sum() - 42. < Tol);
242242
}
243243

244+
{
245+
// Test strided_contraction when second tensor disappears
246+
Tensor<T,4,4,4> a; a.iota(1);
247+
Tensor<T,4,4> b; b.iota(1);
248+
249+
Tensor<T,4> c1 = einsum<Index<i,j,k>,Index<j,k> >(a,b);
250+
Tensor<T,4> c2 = einsum<Index<i,j,k>,Index<i,k> >(a,b);
251+
Tensor<T,4> c3 = einsum<Index<i,j,k>,Index<i,j> >(a,b);
252+
253+
assert (abs(c1(0) - 1496.) < Tol);
254+
assert (abs(c1(1) - 3672.) < Tol);
255+
assert (abs(c1(2) - 5848.) < Tol);
256+
assert (abs(c1(3) - 8024.) < Tol);
257+
258+
assert (abs(c2(0) - 4904.) < Tol);
259+
assert (abs(c2(1) - 5448.) < Tol);
260+
assert (abs(c2(2) - 5992.) < Tol);
261+
assert (abs(c2(3) - 6536.) < Tol);
262+
263+
assert (abs(c3(0) - 5576.) < Tol);
264+
assert (abs(c3(1) - 5712.) < Tol);
265+
assert (abs(c3(2) - 5848.) < Tol);
266+
assert (abs(c3(3) - 5984.) < Tol);
267+
}
268+
244269
print(FGRN(BOLD("All tests passed successfully")));
245270
}
246271

0 commit comments

Comments
 (0)