Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit af22c33

Browse files
committed
fix perchanel fp16 load
1 parent af20dba commit af22c33

File tree

3 files changed

+9
-47
lines changed

3 files changed

+9
-47
lines changed

include/common/core/memory.hpp

+3-18
Original file line numberDiff line numberDiff line change
@@ -476,24 +476,9 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
476476
Y);
477477
return ret.xetla_format<T>();
478478
} else if constexpr (BlockWidth * sizeof(T) < sizeof(uint32_t)) {
479-
constexpr auto scale_factor = sizeof(uint32_t) / sizeof(T);
480-
xetla_vector<uint32_t, N> ret = xetla_load_global<
481-
uint32_t,
482-
BlockWidth,
483-
BlockHeight,
484-
NBlocks,
485-
Transposed,
486-
Transformed,
487-
L1H,
488-
L2H>(
489-
reinterpret_cast<const uint32_t*>(Ptr),
490-
SurfaceWidth,
491-
SurfaceHeight,
492-
SurfacePitch,
493-
X / scale_factor,
494-
Y);
495-
return ret.xetla_format<T>().xetla_select<N, scale_factor>(
496-
X % scale_factor);
479+
xetla_vector<uint32_t, BlockHeight> byte_offsets =
480+
xetla_vector_gen<uint32_t, BlockHeight>(0, SurfacePitch);
481+
return xetla_load_global<T, N, BlockWidth, L1H, L2H>(Ptr, byte_offsets);
497482
} else {
498483
return __ESIMD_ENS::lsc_load_2d<
499484
T,

include/subgroup/tile/impl/load_xe.hpp

+6-15
Original file line numberDiff line numberDiff line change
@@ -236,10 +236,10 @@ tile_load(tile_t& tile, payload_t& payload) {
236236
reg_tmp
237237
.xetla_format<
238238
native_type_t<load_dtype>,
239-
block_size_x / scale_factor,
239+
ld_blk_width / scale_factor,
240240
ld_blk_height>()
241241
.xetla_select<
242-
block_size_x / scale_factor,
242+
ld_blk_width / scale_factor,
243243
1,
244244
ld_blk_size_y,
245245
1>(0, 0);
@@ -297,9 +297,9 @@ tile_load(tile_t& tile, payload_t& payload) {
297297
// xetla_tdescriptor tdesc = payload_row.row(j);
298298
auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
299299
processed_elems + j * remained_block_elems);
300-
// constexpr uint32_t ld_blk_height = (reg_transpose && trans)
301-
// ? detail::getNextPowerOf2<remained_ld_blk_size_y>()
302-
// : remained_ld_blk_size_y;
300+
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
301+
? detail::getNextPowerOf2<remained_ld_blk_size_y>()
302+
: remained_ld_blk_size_y;
303303
constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
304304
xetla_vector<dtype, tmp_size> reg_tmp;
305305
#pragma unroll
@@ -311,7 +311,7 @@ tile_load(tile_t& tile, payload_t& payload) {
311311
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
312312
native_type_t<load_dtype>,
313313
block_size_x / scale_factor,
314-
ld_blk_height,
314+
remained_ld_blk_size_y,
315315
arr_len,
316316
trans,
317317
mem_transform,
@@ -325,15 +325,6 @@ tile_load(tile_t& tile, payload_t& payload) {
325325
payload.offset_x + offset_x / scale_factor,
326326
payload.offset_y + num_block_y * block_size_y +
327327
ii * remained_ld_blk_size_y);
328-
// xetla_tload_global<
329-
// load_dtype,
330-
// (ld_blk_height * block_size_x * arr_len / scale_factor),
331-
// L1,
332-
// L2,
333-
// trans,
334-
// mem_transform,
335-
// arch_tag>(tdesc);
336-
337328
if constexpr (reg_transpose && trans) {
338329
reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
339330
.xetla_format<native_type_t<load_dtype>>() =

include/subgroup/tile/impl/payload_xe.hpp

-14
Original file line numberDiff line numberDiff line change
@@ -1841,20 +1841,6 @@ struct prefetch_payload_t<
18411841
return channel >= 32 ? 32 : channel >= 16 ? 16 : channel >= 8 ? 8 : 1;
18421842
}
18431843

1844-
static constexpr uint32_t num_channel = select_channel(
1845-
std::min(mem_transpose ? block_size_x : block_size_y, max_channel));
1846-
1847-
static constexpr uint32_t max_channel =
1848-
max_prefetch_vec_len / (vector_size * sizeof(prefetch_dtype));
1849-
1850-
static constexpr uint32_t select_channel(const uint32_t channel) {
1851-
return (channel >= load_store_attr::max_channel_num)
1852-
? load_store_attr::max_channel_num
1853-
: channel >= 16 ? 16
1854-
: channel >= 8 ? 8
1855-
: 1;
1856-
}
1857-
18581844
static constexpr uint32_t num_channel = select_channel(
18591845
std::min(mem_transpose ? block_size_x : block_size_y, max_channel));
18601846

0 commit comments

Comments
 (0)