Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 4f30651

Browse files
committed
fix sdp bug
1 parent 8f0abc4 commit 4f30651

File tree

6 files changed

+86
-91
lines changed

6 files changed

+86
-91
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ if (${LOG} STREQUAL "on")
4646
endif ()
4747

4848
# For large registers mode, enable 256 registers for kernels
49-
set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
49+
# set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
5050
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt")
5151
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-codegen")
5252
# Enable bank conflict reduction.

include/experimental/group/gemm/compute_policy.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ struct compute_policy_int4_dequantize<
137137
quant_info_.weight_mem_layout == mem_layout::col_major;
138138

139139
static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16;
140-
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32;
140+
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 128 : 32;
141141
static constexpr uint32_t block_size_x_a =
142142
block_bytes_x_a / sizeof(dtype_mma_a);
143143
static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32;
144-
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 256 : 32;
144+
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 128 : 32;
145145
static constexpr uint32_t block_size_y_b =
146146
block_bytes_y_b / sizeof(dtype_mma_b);
147147

include/experimental/group/gemm/impl/int4_dequantize_xe.hpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,7 @@ class gemm_t<
204204
using matB_payload_t = subgroup::mem_payload_t<
205205
mem_desc_b_t,
206206
matB_tile_desc_t,
207-
subgroup::msg_type_v<
208-
matB_tile_desc_t,
209-
mem_desc_t<
210-
typename mem_desc_b_t::dtype,
211-
mem_layout::row_major,
212-
mem_desc_b_t::space>>,
213-
// subgroup::msg_type_v<matB_tile_desc_t, mem_desc_b_t>,
207+
subgroup::msg_type_v<matB_tile_desc_t, mem_desc_b_t>,
214208
arch_tag>;
215209
using matB_prefetch_payload_t = subgroup::
216210
prefetch_payload_t<mem_desc_b_t, matB_tile_desc_t, wg_size_y, arch_tag>;

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 69 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
namespace gpu::xetla::subgroup {
2727

2828
namespace detail {
29-
template <typename tile_t, typename payload_t, bool is_lsc_gather_ = false>
29+
template <typename tile_t, typename payload_t, bool is_lsc_gather_ = true>
3030
struct check_load_type {
3131
static constexpr bool is_lsc_gather = is_lsc_gather_;
3232
static constexpr bool is_global_block_2d =
@@ -398,8 +398,10 @@ tile_load(tile_t& tile, payload_t& payload) {
398398
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
399399

400400
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
401-
static constexpr uint32_t max_load_vec_len =
402-
load_store_attr::max_load_vec_len;
401+
static constexpr uint32_t max_load_vec_len = std::min(
402+
uint32_t(tile_t::block_elems * sizeof(dtype)),
403+
load_store_attr::max_load_vec_len);
404+
403405
static constexpr uint32_t max_load_vec_elems =
404406
max_load_vec_len / sizeof(dtype);
405407

@@ -465,73 +467,73 @@ tile_load(tile_t& tile, payload_t& payload) {
465467
uint32_t offset_x = j * tile_desc::block_size_x;
466468
auto reg_sub = tile.reg.xetla_select<tile_desc::block_elems, 1>(
467469
(i * tile_desc::num_block_x + j) * tile_desc::block_elems);
468-
// #pragma unroll
469-
// for (uint32_t sub_block_offset = 0; sub_block_offset <
470-
// (payload_t::mem_transpose ? tile_desc::block_size_x
471-
// : tile_desc::block_size_y);
472-
// sub_block_offset += num_channel) {
473-
uint32_t sub_block_offset = 0;
474-
xetla_vector<load_dtype, load_elems> reg_tmp = 0;
475-
uint32_t address_offset = payload_t::mem_transpose
476-
? (offset_x + sub_block_offset) * payload.pitch_in_bytes +
477-
offset_y * sizeof(dtype)
478-
: offset_x * sizeof(dtype) +
479-
(offset_y + sub_block_offset) * payload.pitch_in_bytes;
480-
xetla_mask<num_channel> pred = 1;
481-
if constexpr (num_channel > 1) {
482-
// For SDP load, need pred
483-
const uint32_t sub_block_offset_x = payload.base_x + offset_x +
484-
(payload_t::mem_transpose ? sub_block_offset : 0);
485-
const uint32_t sub_block_offset_y = payload.base_y + offset_y +
486-
(payload_t::mem_transpose ? 0 : sub_block_offset);
487-
const auto offset_ch_dim =
488-
payload_t::trans ? sub_block_offset_x : sub_block_offset_y;
489-
const auto size_ch_dim =
490-
payload_t::trans ? payload.width_in_elems : payload.height_in_elems;
491-
492-
pred = offset_ch_dim + num_channel > size_ch_dim
493-
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
494-
size_ch_dim)
495-
: 1;
496-
}
497-
reg_tmp = xetla_load_global<
498-
load_dtype,
499-
payload_t::simd_exec_size,
500-
data_size::default_size,
501-
L1,
502-
L2,
503-
payload_t::num_channel>(
504-
payload.base_ptr,
505-
payload.channel_offset + payload.base_offset + address_offset,
506-
pred);
507-
508-
if constexpr (
509-
payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) {
510-
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
511470
#pragma unroll
512-
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
513-
if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix
514-
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
515-
iii * payload_t::simd_exec_size) =
516-
reg_tmp.xetla_select<
517-
payload_t::simd_exec_size,
518-
payload_t::num_channel>(iii);
519-
else // TODO (dingyi): Delete after driver fix
520-
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
521-
iii * payload_t::simd_exec_size) = 0;
471+
for (uint32_t sub_block_offset = 0; sub_block_offset <
472+
(payload_t::mem_transpose ? tile_desc::block_size_x
473+
: tile_desc::block_size_y);
474+
sub_block_offset += num_channel) {
475+
// uint32_t sub_block_offset = 0;
476+
xetla_vector<load_dtype, load_elems> reg_tmp = 0;
477+
uint32_t address_offset = payload_t::mem_transpose
478+
? (offset_x + sub_block_offset) * payload.pitch_in_bytes +
479+
offset_y * sizeof(dtype)
480+
: offset_x * sizeof(dtype) +
481+
(offset_y + sub_block_offset) * payload.pitch_in_bytes;
482+
xetla_mask<num_channel> pred = 1;
483+
if constexpr (num_channel > 1) {
484+
// For SDP load, need pred
485+
const uint32_t sub_block_offset_x = payload.base_x + offset_x +
486+
(payload_t::mem_transpose ? sub_block_offset : 0);
487+
const uint32_t sub_block_offset_y = payload.base_y + offset_y +
488+
(payload_t::mem_transpose ? 0 : sub_block_offset);
489+
const auto offset_ch_dim =
490+
payload_t::trans ? sub_block_offset_x : sub_block_offset_y;
491+
const auto size_ch_dim = payload_t::trans ? payload.width_in_elems
492+
: payload.height_in_elems;
493+
494+
pred = offset_ch_dim + num_channel > size_ch_dim
495+
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
496+
size_ch_dim)
497+
: 1;
498+
}
499+
reg_tmp = xetla_load_global<
500+
load_dtype,
501+
payload_t::simd_exec_size,
502+
data_size::default_size,
503+
L1,
504+
L2,
505+
num_channel>(
506+
payload.base_ptr,
507+
payload.channel_offset + payload.base_offset + address_offset,
508+
pred);
509+
510+
if constexpr (
511+
payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) {
512+
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
513+
#pragma unroll
514+
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
515+
if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix
516+
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
517+
iii * payload_t::simd_exec_size) =
518+
reg_tmp.xetla_select<
519+
payload_t::simd_exec_size,
520+
payload_t::num_channel>(iii);
521+
else // TODO (dingyi): Delete after driver fix
522+
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
523+
iii * payload_t::simd_exec_size) = 0;
524+
}
525+
reg_sub
526+
.xetla_select<load_elems * pack_factor, 1>(
527+
sub_block_offset * tile_desc::block_size_x)
528+
.xetla_format<load_dtype>() = reg_tmp_trans;
529+
} else {
530+
reg_sub
531+
.xetla_select<load_elems * pack_factor, 1>(
532+
sub_block_offset * tile_desc::block_size_x)
533+
.xetla_format<load_dtype>() = reg_tmp;
522534
}
523-
reg_sub
524-
.xetla_select<load_elems * pack_factor, 1>(
525-
sub_block_offset * tile_desc::block_size_x)
526-
.xetla_format<load_dtype>() = reg_tmp_trans;
527-
} else {
528-
reg_sub
529-
.xetla_select<load_elems * pack_factor, 1>(
530-
sub_block_offset * tile_desc::block_size_x)
531-
.xetla_format<load_dtype>() = reg_tmp;
532535
}
533536
}
534-
// }
535537
}
536538

537539
if constexpr (payload_t::trans) {
@@ -595,7 +597,7 @@ tile_load(tile_t& tile, payload_t& payload) {
595597

596598
reg_sub.xetla_select<load_elems, 1>(
597599
sub_block_y * tile_desc::block_size_x) =
598-
xetla_load_global<dtype, load_elems>(
600+
xetla_load_global<dtype, load_elems, L1, L2>(
599601
(dtype*)payload.base_ptr, payload.base_offset + address_offset);
600602
}
601603
}

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,24 +1152,23 @@ struct mem_payload_t<
11521152
uint32_t,
11531153
dtype>::type>::type;
11541154
static constexpr uint32_t pack_factor = sizeof(mem_dtype) / sizeof(dtype);
1155-
// for pvc, we can use simd16 or simd32
1156-
static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype);
1157-
static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype);
1158-
static constexpr uint32_t simd_channel =
1159-
((tile_bytes % max_store_bytes) == 0 &&
1160-
(block_bytes % max_store_bytes) == 0)
1161-
? 32
1162-
: 16;
1163-
static constexpr uint32_t num_channel = mem_transpose
1164-
? (simd_channel >= block_size_x) ? block_size_x : simd_channel
1165-
: (simd_channel >= block_size_y) ? block_size_y
1166-
: simd_channel;
11671155

11681156
static constexpr uint32_t simd_exec_size =
11691157
(mem_transpose ? block_size_y : block_size_x) >= pack_factor
11701158
? (mem_transpose ? block_size_y : block_size_x) / pack_factor
11711159
: 1;
11721160

1161+
// for pvc, we can use simd16 or simd32
1162+
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
1163+
static constexpr uint32_t max_bytes = load_store_attr::max_load_vec_len;
1164+
1165+
static constexpr uint32_t simd_channel =
1166+
max_bytes / (simd_exec_size * sizeof(mem_dtype));
1167+
1168+
static constexpr uint32_t num_channel = mem_transpose
1169+
? std::min(block_size_x, simd_channel)
1170+
: std::min(block_size_y, simd_channel);
1171+
11731172
xetla_vector<uint32_t, num_channel> channel_offset;
11741173
xetla_vector<uint32_t, num_channel> step_x;
11751174
xetla_vector<uint32_t, num_channel> step_y;

tests/integration/gemv/int4/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class test_col_major_1 {
3737
static constexpr size_t wg_n = 1;
3838
static constexpr size_t sg_m = 1;
3939
static constexpr size_t sg_n = 1;
40-
static constexpr size_t sg_k = 1024 / 1;
40+
static constexpr size_t sg_k = 512 / 1;
4141
static constexpr size_t dequant_s = 128;
4242
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
4343
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
@@ -47,7 +47,7 @@ class test_col_major_1 {
4747
static constexpr mem_layout layout_a = mem_layout::row_major;
4848
static constexpr mem_layout layout_b = mem_layout::col_major;
4949
static constexpr mma_engine mma_eng = mma_engine::fpu;
50-
static constexpr gpu_arch arch = gpu_arch::XeHpc;
50+
static constexpr gpu_arch arch = gpu_arch::XeLpg;
5151
using data_type_a = fp16;
5252
using data_type_b = int4x8;
5353
using data_type_c = fp16;

0 commit comments

Comments
 (0)