Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 20c2056

Browse files
committed
save
1 parent 90662a8 commit 20c2056

File tree

1 file changed

+23
-29
lines changed

1 file changed

+23
-29
lines changed

include/subgroup/tile/impl/load_xe.hpp

+23-29
Original file line numberDiff line numberDiff line change
@@ -119,8 +119,9 @@ tile_load(tile_t& tile, payload_t& payload) {
119119
static constexpr uint32_t max_load_width_in_elem =
120120
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
121121

122-
// static constexpr uint32_t max_trans_load_height_in_elem =
123-
// load_store_attr::max_trans_load_height_in_elem;
122+
static constexpr uint32_t max_trans_load_height_in_elem =
123+
load_store_attr::max_trans_load_height_in_elem;
124+
124125
static constexpr uint32_t max_load_height_in_elem =
125126
load_store_attr::max_load_height_in_elem;
126127

@@ -130,6 +131,22 @@ tile_load(tile_t& tile, payload_t& payload) {
130131
static constexpr uint32_t elems_per_reg =
131132
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
132133

134+
static constexpr uint32_t max_ld_blk_width_in_elem =
135+
trans ? max_trans_load_width_in_elem : max_load_width_in_elem;
136+
137+
static constexpr uint32_t max_ld_blk_height_in_elem =
138+
trans ? max_trans_load_height_in_elem : max_load_height_in_elem;
139+
140+
static constexpr uint32_t ld_blk_width =
141+
std::min(
142+
mem_transpose ? block_size_y : block_size_x,
143+
max_ld_blk_width_in_elem) /
144+
scale_factor;
145+
static constexpr uint32_t ld_blk_height = std::min(
146+
mem_transpose ? block_size_x : block_size_y, max_ld_blk_height_in_elem);
147+
148+
149+
133150
static constexpr uint32_t ld_blk_size_y_limit =
134151
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135152
static constexpr uint32_t ld_blk_size_y = reg_transpose
@@ -211,12 +228,11 @@ tile_load(tile_t& tile, payload_t& payload) {
211228
scale_factor;
212229
uint32_t address_offset_y =
213230
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
231+
214232
reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
215233
native_type_t<load_dtype>,
216-
(trans ? ld_blk_size_y : block_size_x) / scale_factor,
217-
(trans ? block_size_x : ld_blk_size_y),
218-
// block_size_x / scale_factor,
219-
// ld_blk_size_y,
234+
ld_blk_width,
235+
ld_blk_height,
220236
arr_len,
221237
trans,
222238
mem_transform,
@@ -261,11 +277,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261277
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262278
constexpr uint8_t block_height =
263279
mem_transpose ? block_size_x : remained_blk_size_y;
264-
// constexpr uint32_t block_widthx_widthy_arrlen =
265-
// (block_width - 1) | ((block_height - 1) << 8);
266-
// gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267-
// tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268-
269280
reg_blk.xetla_select<load_elems, 1>(remained_start)
270281
.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
271282
native_type_t<load_dtype>,
@@ -283,15 +294,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283294
payload.surface_pitch,
284295
payload.offset_x + offset_x / scale_factor,
285296
payload.offset_y + offset_y + remained_start_y);
286-
287-
// xetla_tload_global<
288-
// load_dtype,
289-
// (load_elems / scale_factor),
290-
// L1,
291-
// L2,
292-
// trans,
293-
// mem_transform,
294-
// arch_tag>(tdesc);
295297
}
296298
}
297299
}
@@ -304,15 +306,7 @@ tile_load(tile_t& tile, payload_t& payload) {
304306
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305307
? ld_blk_size_y_limit
306308
: remained_size_y;
307-
// auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308-
// num_block_y * num_block_x, 0);
309-
// detail::reset_tile_desc_core<
310-
// num_block_x,
311-
// block_size_x,
312-
// remained_ld_blk_size_y,
313-
// scale_factor,
314-
// arr_len,
315-
// mem_transpose>(payload_row);
309+
316310
#pragma unroll
317311
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
318312
int32_t offset_x = j * block_size_x;

0 commit comments

Comments
 (0)