Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 82e9ce2

Browse files
committed
fix per-channel fp16 load
1 parent 4b56d54 commit 82e9ce2

File tree

2 files changed

+9
-33
lines changed

2 files changed

+9
-33
lines changed

include/common/core/memory.hpp

+3 -18
Original file line number | Diff line number | Diff line change
@@ -476,24 +476,9 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
476476
Y);
477477
return ret.xetla_format<T>();
478478
} else if constexpr (BlockWidth * sizeof(T) < sizeof(uint32_t)) {
479-
constexpr auto scale_factor = sizeof(uint32_t) / sizeof(T);
480-
xetla_vector<uint32_t, N> ret = xetla_load_global<
481-
uint32_t,
482-
BlockWidth,
483-
BlockHeight,
484-
NBlocks,
485-
Transposed,
486-
Transformed,
487-
L1H,
488-
L2H>(
489-
reinterpret_cast<const uint32_t*>(Ptr),
490-
SurfaceWidth,
491-
SurfaceHeight,
492-
SurfacePitch,
493-
X / scale_factor,
494-
Y);
495-
return ret.xetla_format<T>().xetla_select<N, scale_factor>(
496-
X % scale_factor);
479+
xetla_vector<uint32_t, BlockHeight> byte_offsets =
480+
xetla_vector_gen<uint32_t, BlockHeight>(0, SurfacePitch);
481+
return xetla_load_global<T, N, BlockWidth, L1H, L2H>(Ptr, byte_offsets);
497482
} else {
498483
return __ESIMD_ENS::lsc_load_2d<
499484
T,

include/subgroup/tile/impl/load_xe.hpp

+6 -15
Original file line number | Diff line number | Diff line change
@@ -236,10 +236,10 @@ tile_load(tile_t& tile, payload_t& payload) {
236236
reg_tmp
237237
.xetla_format<
238238
native_type_t<load_dtype>,
239-
block_size_x / scale_factor,
239+
ld_blk_width / scale_factor,
240240
ld_blk_height>()
241241
.xetla_select<
242-
block_size_x / scale_factor,
242+
ld_blk_width / scale_factor,
243243
1,
244244
ld_blk_size_y,
245245
1>(0, 0);
@@ -297,9 +297,9 @@ tile_load(tile_t& tile, payload_t& payload) {
297297
// xetla_tdescriptor tdesc = payload_row.row(j);
298298
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
299299
processed_elems + j * remained_block_elems);
300-
// constexpr uint32_t ld_blk_height = (reg_transpose && trans)
301-
// ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
302-
// : remained_ld_blk_size_y;
300+
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
301+
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
302+
: remained_ld_blk_size_y;
303303
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
304304
xetla_vector<dtype, tmp_size> reg_tmp;
305305
#pragma unroll
@@ -311,7 +311,7 @@ tile_load(tile_t& tile, payload_t& payload) {
311311
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
312312
native_type_t<load_dtype>,
313313
block_size_x / scale_factor,
314-
ld_blk_height,
314+
remained_ld_blk_size_y,
315315
arr_len,
316316
trans,
317317
mem_transform,
@@ -325,15 +325,6 @@ tile_load(tile_t& tile, payload_t& payload) {
325325
payload.offset_x + offset_x / scale_factor,
326326
payload.offset_y + num_block_y * block_size_y +
327327
ii * remained_ld_blk_size_y);
328-
// xetla_tload_global<
329-
// load_dtype,
330-
// (ld_blk_height * block_size_x * arr_len / scale_factor),
331-
// L1,
332-
// L2,
333-
// trans,
334-
// mem_transform,
335-
// arch_tag>(tdesc);
336-
337328
if constexpr (reg_transpose && trans) {
338329
reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
339330
.xetla_format<native_type_t<load_dtype>>() =

0 commit comments

Comments
 (0)