|
26 | 26 | namespace gpu::xetla::subgroup {
|
27 | 27 |
|
28 | 28 | namespace detail {
|
29 |
| -template <typename tile_t, typename payload_t, bool is_lsc_gather_ = false> |
| 29 | +template <typename tile_t, typename payload_t, bool is_lsc_gather_ = true> |
30 | 30 | struct check_load_type {
|
31 | 31 | static constexpr bool is_lsc_gather = is_lsc_gather_;
|
32 | 32 | static constexpr bool is_global_block_2d =
|
@@ -465,73 +465,73 @@ tile_load(tile_t& tile, payload_t& payload) {
|
465 | 465 | uint32_t offset_x = j * tile_desc::block_size_x;
|
466 | 466 | auto reg_sub = tile.reg.xetla_select<tile_desc::block_elems, 1>(
|
467 | 467 | (i * tile_desc::num_block_x + j) * tile_desc::block_elems);
|
468 |
| - // #pragma unroll |
469 |
| - // for (uint32_t sub_block_offset = 0; sub_block_offset < |
470 |
| - // (payload_t::mem_transpose ? tile_desc::block_size_x |
471 |
| - // : tile_desc::block_size_y); |
472 |
| - // sub_block_offset += num_channel) { |
473 |
| - uint32_t sub_block_offset = 0; |
474 |
| - xetla_vector<load_dtype, load_elems> reg_tmp = 0; |
475 |
| - uint32_t address_offset = payload_t::mem_transpose |
476 |
| - ? (offset_x + sub_block_offset) * payload.pitch_in_bytes + |
477 |
| - offset_y * sizeof(dtype) |
478 |
| - : offset_x * sizeof(dtype) + |
479 |
| - (offset_y + sub_block_offset) * payload.pitch_in_bytes; |
480 |
| - xetla_mask<num_channel> pred = 1; |
481 |
| - if constexpr (num_channel > 1) { |
482 |
| - // For SDP load, need pred |
483 |
| - const uint32_t sub_block_offset_x = payload.base_x + offset_x + |
484 |
| - (payload_t::mem_transpose ? sub_block_offset : 0); |
485 |
| - const uint32_t sub_block_offset_y = payload.base_y + offset_y + |
486 |
| - (payload_t::mem_transpose ? 0 : sub_block_offset); |
487 |
| - const auto offset_ch_dim = |
488 |
| - payload_t::trans ? sub_block_offset_x : sub_block_offset_y; |
489 |
| - const auto size_ch_dim = |
490 |
| - payload_t::trans ? payload.width_in_elems : payload.height_in_elems; |
491 |
| - |
492 |
| - pred = offset_ch_dim + num_channel > size_ch_dim |
493 |
| - ? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) < |
494 |
| - size_ch_dim) |
495 |
| - : 1; |
496 |
| - } |
497 |
| - reg_tmp = xetla_load_global< |
498 |
| - load_dtype, |
499 |
| - payload_t::simd_exec_size, |
500 |
| - data_size::default_size, |
501 |
| - L1, |
502 |
| - L2, |
503 |
| - payload_t::num_channel>( |
504 |
| - payload.base_ptr, |
505 |
| - payload.channel_offset + payload.base_offset + address_offset, |
506 |
| - pred); |
507 |
| - |
508 |
| - if constexpr ( |
509 |
| - payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) { |
510 |
| - xetla_vector<load_dtype, load_elems> reg_tmp_trans; |
511 | 468 | #pragma unroll
|
512 |
| - for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { |
513 |
| - if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix |
514 |
| - reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>( |
515 |
| - iii * payload_t::simd_exec_size) = |
516 |
| - reg_tmp.xetla_select< |
517 |
| - payload_t::simd_exec_size, |
518 |
| - payload_t::num_channel>(iii); |
519 |
| - else // TODO (dingyi): Delete after driver fix |
520 |
| - reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>( |
521 |
| - iii * payload_t::simd_exec_size) = 0; |
| 469 | + for (uint32_t sub_block_offset = 0; sub_block_offset < |
| 470 | + (payload_t::mem_transpose ? tile_desc::block_size_x |
| 471 | + : tile_desc::block_size_y); |
| 472 | + sub_block_offset += num_channel) { |
| 473 | + // uint32_t sub_block_offset = 0; |
| 474 | + xetla_vector<load_dtype, load_elems> reg_tmp = 0; |
| 475 | + uint32_t address_offset = payload_t::mem_transpose |
| 476 | + ? (offset_x + sub_block_offset) * payload.pitch_in_bytes + |
| 477 | + offset_y * sizeof(dtype) |
| 478 | + : offset_x * sizeof(dtype) + |
| 479 | + (offset_y + sub_block_offset) * payload.pitch_in_bytes; |
| 480 | + xetla_mask<num_channel> pred = 1; |
| 481 | + if constexpr (num_channel > 1) { |
| 482 | + // For SDP load, need pred |
| 483 | + const uint32_t sub_block_offset_x = payload.base_x + offset_x + |
| 484 | + (payload_t::mem_transpose ? sub_block_offset : 0); |
| 485 | + const uint32_t sub_block_offset_y = payload.base_y + offset_y + |
| 486 | + (payload_t::mem_transpose ? 0 : sub_block_offset); |
| 487 | + const auto offset_ch_dim = |
| 488 | + payload_t::trans ? sub_block_offset_x : sub_block_offset_y; |
| 489 | + const auto size_ch_dim = payload_t::trans ? payload.width_in_elems |
| 490 | + : payload.height_in_elems; |
| 491 | + |
| 492 | + pred = offset_ch_dim + num_channel > size_ch_dim |
| 493 | + ? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) < |
| 494 | + size_ch_dim) |
| 495 | + : 1; |
| 496 | + } |
| 497 | + reg_tmp = xetla_load_global< |
| 498 | + load_dtype, |
| 499 | + payload_t::simd_exec_size, |
| 500 | + data_size::default_size, |
| 501 | + L1, |
| 502 | + L2, |
| 503 | + num_channel>( |
| 504 | + payload.base_ptr, |
| 505 | + payload.channel_offset + payload.base_offset + address_offset, |
| 506 | + pred); |
| 507 | + |
| 508 | + if constexpr ( |
| 509 | + payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) { |
| 510 | + xetla_vector<load_dtype, load_elems> reg_tmp_trans; |
| 511 | +#pragma unroll |
| 512 | + for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) { |
| 513 | + if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix |
| 514 | + reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>( |
| 515 | + iii * payload_t::simd_exec_size) = |
| 516 | + reg_tmp.xetla_select< |
| 517 | + payload_t::simd_exec_size, |
| 518 | + payload_t::num_channel>(iii); |
| 519 | + else // TODO (dingyi): Delete after driver fix |
| 520 | + reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>( |
| 521 | + iii * payload_t::simd_exec_size) = 0; |
| 522 | + } |
| 523 | + reg_sub |
| 524 | + .xetla_select<load_elems * pack_factor, 1>( |
| 525 | + sub_block_offset * tile_desc::block_size_x) |
| 526 | + .xetla_format<load_dtype>() = reg_tmp_trans; |
| 527 | + } else { |
| 528 | + reg_sub |
| 529 | + .xetla_select<load_elems * pack_factor, 1>( |
| 530 | + sub_block_offset * tile_desc::block_size_x) |
| 531 | + .xetla_format<load_dtype>() = reg_tmp; |
522 | 532 | }
|
523 |
| - reg_sub |
524 |
| - .xetla_select<load_elems * pack_factor, 1>( |
525 |
| - sub_block_offset * tile_desc::block_size_x) |
526 |
| - .xetla_format<load_dtype>() = reg_tmp_trans; |
527 |
| - } else { |
528 |
| - reg_sub |
529 |
| - .xetla_select<load_elems * pack_factor, 1>( |
530 |
| - sub_block_offset * tile_desc::block_size_x) |
531 |
| - .xetla_format<load_dtype>() = reg_tmp; |
532 | 533 | }
|
533 | 534 | }
|
534 |
| - // } |
535 | 535 | }
|
536 | 536 |
|
537 | 537 | if constexpr (payload_t::trans) {
|
|
0 commit comments