@@ -119,8 +119,9 @@ tile_load(tile_t& tile, payload_t& payload) {
119
119
static constexpr uint32_t max_load_width_in_elem =
120
120
load_store_attr::max_load_width_in_bytes / sizeof (dtype);
121
121
122
- // static constexpr uint32_t max_trans_load_height_in_elem =
123
- // load_store_attr::max_trans_load_height_in_elem;
122
+ static constexpr uint32_t max_trans_load_height_in_elem =
123
+ load_store_attr::max_trans_load_height_in_elem;
124
+
124
125
static constexpr uint32_t max_load_height_in_elem =
125
126
load_store_attr::max_load_height_in_elem;
126
127
@@ -130,6 +131,22 @@ tile_load(tile_t& tile, payload_t& payload) {
130
131
static constexpr uint32_t elems_per_reg =
131
132
register_bytes_t <arch_tag>::reg_in_bytes / sizeof (dtype);
132
133
134
+ static constexpr uint32_t max_ld_blk_width_in_elem =
135
+ trans ? max_trans_load_width_in_elem : max_load_width_in_elem;
136
+
137
+ static constexpr uint32_t max_ld_blk_height_in_elem =
138
+ trans ? max_trans_load_height_in_elem : max_load_height_in_elem;
139
+
140
+ static constexpr uint32_t ld_blk_width =
141
+ std::min (
142
+ mem_transpose ? block_size_y : block_size_x,
143
+ max_ld_blk_width_in_elem) /
144
+ scale_factor;
145
+ static constexpr uint32_t ld_blk_height = std::min (
146
+ mem_transpose ? block_size_x : block_size_y, max_ld_blk_height_in_elem);
147
+
148
+
149
+
133
150
static constexpr uint32_t ld_blk_size_y_limit =
134
151
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
135
152
static constexpr uint32_t ld_blk_size_y = reg_transpose
@@ -211,12 +228,11 @@ tile_load(tile_t& tile, payload_t& payload) {
211
228
scale_factor;
212
229
uint32_t address_offset_y =
213
230
mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
231
+
214
232
reg_tmp.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
215
233
native_type_t <load_dtype>,
216
- (trans ? ld_blk_size_y : block_size_x) / scale_factor,
217
- (trans ? block_size_x : ld_blk_size_y),
218
- // block_size_x / scale_factor,
219
- // ld_blk_size_y,
234
+ ld_blk_width,
235
+ ld_blk_height,
220
236
arr_len,
221
237
trans,
222
238
mem_transform,
@@ -261,11 +277,6 @@ tile_load(tile_t& tile, payload_t& payload) {
261
277
(mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
262
278
constexpr uint8_t block_height =
263
279
mem_transpose ? block_size_x : remained_blk_size_y;
264
- // constexpr uint32_t block_widthx_widthy_arrlen =
265
- // (block_width - 1) | ((block_height - 1) << 8);
266
- // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
267
- // tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
268
-
269
280
reg_blk.xetla_select <load_elems, 1 >(remained_start)
270
281
.xetla_format <native_type_t <load_dtype>>() = xetla_load_global<
271
282
native_type_t <load_dtype>,
@@ -283,15 +294,6 @@ tile_load(tile_t& tile, payload_t& payload) {
283
294
payload.surface_pitch ,
284
295
payload.offset_x + offset_x / scale_factor,
285
296
payload.offset_y + offset_y + remained_start_y);
286
-
287
- // xetla_tload_global<
288
- // load_dtype,
289
- // (load_elems / scale_factor),
290
- // L1,
291
- // L2,
292
- // trans,
293
- // mem_transform,
294
- // arch_tag>(tdesc);
295
297
}
296
298
}
297
299
}
@@ -304,15 +306,7 @@ tile_load(tile_t& tile, payload_t& payload) {
304
306
(!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
305
307
? ld_blk_size_y_limit
306
308
: remained_size_y;
307
- // auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
308
- // num_block_y * num_block_x, 0);
309
- // detail::reset_tile_desc_core<
310
- // num_block_x,
311
- // block_size_x,
312
- // remained_ld_blk_size_y,
313
- // scale_factor,
314
- // arr_len,
315
- // mem_transpose>(payload_row);
309
+
316
310
#pragma unroll
317
311
for (uint32_t j = 0 ; j < num_block_x; j += arr_len) {
318
312
int32_t offset_x = j * block_size_x;
0 commit comments