Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit ea1267d

Browse files
committed
fix sdp bug
1 parent 8f0abc4 commit ea1267d

File tree

5 files changed

+81
-81
lines changed

5 files changed

+81
-81
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ if (${LOG} STREQUAL "on")
4646
endif ()
4747

4848
# For large registers mode, enable 256 registers for kernels
49-
set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
49+
# set(XETLA_OFFLINE_OPTIONS "-doubleGRF")
5050
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-disable-indvars-opt")
5151
set(XETLA_OFFLINE_OPTIONS "${XETLA_OFFLINE_OPTIONS} -vc-codegen")
5252
# Enable bank conflict reduction.

include/experimental/group/gemm/compute_policy.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ struct compute_policy_int4_dequantize<
137137
quant_info_.weight_mem_layout == mem_layout::col_major;
138138

139139
static constexpr uint32_t block_size_y_a = is_col_major_b ? 8 : 16;
140-
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 256 : 32;
140+
static constexpr uint32_t block_bytes_x_a = is_col_major_b ? 128 : 32;
141141
static constexpr uint32_t block_size_x_a =
142142
block_bytes_x_a / sizeof(dtype_mma_a);
143143
static constexpr uint32_t block_size_x_b = is_col_major_b ? 1 : 32;
144-
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 256 : 32;
144+
static constexpr uint32_t block_bytes_y_b = is_col_major_b ? 128 : 32;
145145
static constexpr uint32_t block_size_y_b =
146146
block_bytes_y_b / sizeof(dtype_mma_b);
147147

include/subgroup/tile/impl/load_xe.hpp

Lines changed: 64 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
namespace gpu::xetla::subgroup {
2727

2828
namespace detail {
29-
template <typename tile_t, typename payload_t, bool is_lsc_gather_ = false>
29+
template <typename tile_t, typename payload_t, bool is_lsc_gather_ = true>
3030
struct check_load_type {
3131
static constexpr bool is_lsc_gather = is_lsc_gather_;
3232
static constexpr bool is_global_block_2d =
@@ -465,73 +465,73 @@ tile_load(tile_t& tile, payload_t& payload) {
465465
uint32_t offset_x = j * tile_desc::block_size_x;
466466
auto reg_sub = tile.reg.xetla_select<tile_desc::block_elems, 1>(
467467
(i * tile_desc::num_block_x + j) * tile_desc::block_elems);
468-
// #pragma unroll
469-
// for (uint32_t sub_block_offset = 0; sub_block_offset <
470-
// (payload_t::mem_transpose ? tile_desc::block_size_x
471-
// : tile_desc::block_size_y);
472-
// sub_block_offset += num_channel) {
473-
uint32_t sub_block_offset = 0;
474-
xetla_vector<load_dtype, load_elems> reg_tmp = 0;
475-
uint32_t address_offset = payload_t::mem_transpose
476-
? (offset_x + sub_block_offset) * payload.pitch_in_bytes +
477-
offset_y * sizeof(dtype)
478-
: offset_x * sizeof(dtype) +
479-
(offset_y + sub_block_offset) * payload.pitch_in_bytes;
480-
xetla_mask<num_channel> pred = 1;
481-
if constexpr (num_channel > 1) {
482-
// For SDP load, need pred
483-
const uint32_t sub_block_offset_x = payload.base_x + offset_x +
484-
(payload_t::mem_transpose ? sub_block_offset : 0);
485-
const uint32_t sub_block_offset_y = payload.base_y + offset_y +
486-
(payload_t::mem_transpose ? 0 : sub_block_offset);
487-
const auto offset_ch_dim =
488-
payload_t::trans ? sub_block_offset_x : sub_block_offset_y;
489-
const auto size_ch_dim =
490-
payload_t::trans ? payload.width_in_elems : payload.height_in_elems;
491-
492-
pred = offset_ch_dim + num_channel > size_ch_dim
493-
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
494-
size_ch_dim)
495-
: 1;
496-
}
497-
reg_tmp = xetla_load_global<
498-
load_dtype,
499-
payload_t::simd_exec_size,
500-
data_size::default_size,
501-
L1,
502-
L2,
503-
payload_t::num_channel>(
504-
payload.base_ptr,
505-
payload.channel_offset + payload.base_offset + address_offset,
506-
pred);
507-
508-
if constexpr (
509-
payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) {
510-
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
511468
#pragma unroll
512-
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
513-
if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix
514-
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
515-
iii * payload_t::simd_exec_size) =
516-
reg_tmp.xetla_select<
517-
payload_t::simd_exec_size,
518-
payload_t::num_channel>(iii);
519-
else // TODO (dingyi): Delete after driver fix
520-
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
521-
iii * payload_t::simd_exec_size) = 0;
469+
for (uint32_t sub_block_offset = 0; sub_block_offset <
470+
(payload_t::mem_transpose ? tile_desc::block_size_x
471+
: tile_desc::block_size_y);
472+
sub_block_offset += num_channel) {
473+
// uint32_t sub_block_offset = 0;
474+
xetla_vector<load_dtype, load_elems> reg_tmp = 0;
475+
uint32_t address_offset = payload_t::mem_transpose
476+
? (offset_x + sub_block_offset) * payload.pitch_in_bytes +
477+
offset_y * sizeof(dtype)
478+
: offset_x * sizeof(dtype) +
479+
(offset_y + sub_block_offset) * payload.pitch_in_bytes;
480+
xetla_mask<num_channel> pred = 1;
481+
if constexpr (num_channel > 1) {
482+
// For SDP load, need pred
483+
const uint32_t sub_block_offset_x = payload.base_x + offset_x +
484+
(payload_t::mem_transpose ? sub_block_offset : 0);
485+
const uint32_t sub_block_offset_y = payload.base_y + offset_y +
486+
(payload_t::mem_transpose ? 0 : sub_block_offset);
487+
const auto offset_ch_dim =
488+
payload_t::trans ? sub_block_offset_x : sub_block_offset_y;
489+
const auto size_ch_dim = payload_t::trans ? payload.width_in_elems
490+
: payload.height_in_elems;
491+
492+
pred = offset_ch_dim + num_channel > size_ch_dim
493+
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
494+
size_ch_dim)
495+
: 1;
496+
}
497+
reg_tmp = xetla_load_global<
498+
load_dtype,
499+
payload_t::simd_exec_size,
500+
data_size::default_size,
501+
L1,
502+
L2,
503+
num_channel>(
504+
payload.base_ptr,
505+
payload.channel_offset + payload.base_offset + address_offset,
506+
pred);
507+
508+
if constexpr (
509+
payload_t::simd_exec_size > 1 && payload_t::num_channel > 1) {
510+
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
511+
#pragma unroll
512+
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
513+
if ((bool)pred[iii]) // TODO (dingyi): Delete after driver fix
514+
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
515+
iii * payload_t::simd_exec_size) =
516+
reg_tmp.xetla_select<
517+
payload_t::simd_exec_size,
518+
payload_t::num_channel>(iii);
519+
else // TODO (dingyi): Delete after driver fix
520+
reg_tmp_trans.xetla_select<payload_t::simd_exec_size, 1>(
521+
iii * payload_t::simd_exec_size) = 0;
522+
}
523+
reg_sub
524+
.xetla_select<load_elems * pack_factor, 1>(
525+
sub_block_offset * tile_desc::block_size_x)
526+
.xetla_format<load_dtype>() = reg_tmp_trans;
527+
} else {
528+
reg_sub
529+
.xetla_select<load_elems * pack_factor, 1>(
530+
sub_block_offset * tile_desc::block_size_x)
531+
.xetla_format<load_dtype>() = reg_tmp;
522532
}
523-
reg_sub
524-
.xetla_select<load_elems * pack_factor, 1>(
525-
sub_block_offset * tile_desc::block_size_x)
526-
.xetla_format<load_dtype>() = reg_tmp_trans;
527-
} else {
528-
reg_sub
529-
.xetla_select<load_elems * pack_factor, 1>(
530-
sub_block_offset * tile_desc::block_size_x)
531-
.xetla_format<load_dtype>() = reg_tmp;
532533
}
533534
}
534-
// }
535535
}
536536

537537
if constexpr (payload_t::trans) {

include/subgroup/tile/impl/payload_xe.hpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,24 +1152,24 @@ struct mem_payload_t<
11521152
uint32_t,
11531153
dtype>::type>::type;
11541154
static constexpr uint32_t pack_factor = sizeof(mem_dtype) / sizeof(dtype);
1155-
// for pvc, we can use simd16 or simd32
1156-
static constexpr uint32_t min_store_bytes = 16 * sizeof(dtype);
1157-
static constexpr uint32_t max_store_bytes = 32 * sizeof(dtype);
1158-
static constexpr uint32_t simd_channel =
1159-
((tile_bytes % max_store_bytes) == 0 &&
1160-
(block_bytes % max_store_bytes) == 0)
1161-
? 32
1162-
: 16;
1163-
static constexpr uint32_t num_channel = mem_transpose
1164-
? (simd_channel >= block_size_x) ? block_size_x : simd_channel
1165-
: (simd_channel >= block_size_y) ? block_size_y
1166-
: simd_channel;
11671155

11681156
static constexpr uint32_t simd_exec_size =
11691157
(mem_transpose ? block_size_y : block_size_x) >= pack_factor
11701158
? (mem_transpose ? block_size_y : block_size_x) / pack_factor
11711159
: 1;
11721160

1161+
// for pvc, we can use simd16 or simd32
1162+
using load_store_attr = load_store_attr_t<msg_type::block_1d, arch_tag>;
1163+
static constexpr uint32_t max_bytes =
1164+
load_store_attr::max_load_vec_len;
1165+
1166+
static constexpr uint32_t simd_channel =
1167+
max_bytes / (simd_exec_size * sizeof(mem_dtype));
1168+
1169+
static constexpr uint32_t num_channel = mem_transpose
1170+
? std::min(block_size_x, simd_channel)
1171+
: std::min(block_size_y, simd_channel);
1172+
11731173
xetla_vector<uint32_t, num_channel> channel_offset;
11741174
xetla_vector<uint32_t, num_channel> step_x;
11751175
xetla_vector<uint32_t, num_channel> step_y;

tests/integration/gemv/int4/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class test_col_major_1 {
3737
static constexpr size_t wg_n = 1;
3838
static constexpr size_t sg_m = 1;
3939
static constexpr size_t sg_n = 1;
40-
static constexpr size_t sg_k = 1024 / 1;
40+
static constexpr size_t sg_k = 512 / 1;
4141
static constexpr size_t dequant_s = 128;
4242
// static constexpr quant_mode quant_mode = quant_mode::S4_ASYM;
4343
static constexpr quant_mode quant_mode = quant_mode::S4_FULLRANGE_NO_ZP;
@@ -47,7 +47,7 @@ class test_col_major_1 {
4747
static constexpr mem_layout layout_a = mem_layout::row_major;
4848
static constexpr mem_layout layout_b = mem_layout::col_major;
4949
static constexpr mma_engine mma_eng = mma_engine::fpu;
50-
static constexpr gpu_arch arch = gpu_arch::XeHpc;
50+
static constexpr gpu_arch arch = gpu_arch::XeLpg;
5151
using data_type_a = fp16;
5252
using data_type_b = int4x8;
5353
using data_type_c = fp16;

0 commit comments

Comments
 (0)