Fix splitk preshuffle (#3137)

* Fix splitK multiply_multiply_wp

* Add tests for gemm_multiply_multiply_wp

* Add tests for gemm_universal_preshuffle (KBatch = 1)

* Add tests for gemm_blockscale_wp

* Fix splitk gemm universal preshuffle

* Run new tests on arch supporting fp8

* Restore example

* Fix strides profiler

* Fix tests

* Fix clang format

* Finalize profiler preshuffle with tolerances

* Minor improvements to splitk related changes

* Address review comments: clang format and ckProfiler typo

* Remove b_k_split_offset from SplitKBatchOffset struct
This commit is contained in:
Enrico Degregori
2025-11-03 20:59:01 +01:00
committed by GitHub
parent 057b7d43b4
commit 507d81c3af
19 changed files with 777 additions and 172 deletions

View File

@@ -40,14 +40,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// Full K needed for matrix B
const index_t Kt = karg.K;
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
const index_t k_id = blockIdx.z * num_k_per_block;
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_grid,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared,
karg);
karg,
k_id,
Kt);
}
#else
ignore = karg;
@@ -74,15 +82,23 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// Full K needed for matrix B
const index_t Kt = karg.K;
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
const index_t k_id = blockIdx.z * num_k_per_block;
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_grid,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
p_shared_0,
p_shared_1,
karg);
karg,
k_id,
Kt);
}
#else
ignore = karg;
@@ -658,25 +674,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
// b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
b_k_split_offset = blockIdx.z * karg.KRead * NLane / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
}
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
@@ -697,7 +694,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t c_reduce_offset;
};
@@ -900,6 +896,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
(NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
"Invalid tuning param!");
if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0)
{
return false;
}
if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
@@ -1134,7 +1135,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock)
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -1226,7 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
k_id,
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
@@ -1465,10 +1467,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
const BDataType* p_b_grid,
CDataType* p_c_grid,
void* p_shared,
const Problem& problem)
const Problem& problem,
const index_t k_id,
const index_t Kt)
{
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
@@ -1491,7 +1495,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
problem,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bpreshuffled,
c_grid_desc_mblock_mperblock_nblock_nperblock);
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id);
}
template <typename AGridDesc_AK0_M_K1,
@@ -1509,7 +1514,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
c_grid_desc_mblock_mperblock_nblock_nperblock)
c_grid_desc_mblock_mperblock_nblock_nperblock,
const index_t k_id)
{
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
@@ -1606,7 +1612,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
k_id,
KPack * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
@@ -1849,10 +1855,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
CDataType* p_c_grid,
void* p_shared_0,
void* p_shared_1,
const Problem& problem)
const Problem& problem,
const index_t k_id,
const index_t Kt)
{
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bpreshuffled =
@@ -1877,7 +1885,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle
problem,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bpreshuffled,
c_grid_desc_mblock_mperblock_nblock_nperblock);
c_grid_desc_mblock_mperblock_nblock_nperblock,
k_id);
}
};

View File

@@ -43,18 +43,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// Full K needed for matrix B
const index_t Kt = karg.K;
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
const index_t k_id = blockIdx.z * num_k_per_block;
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_grid,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
karg.c_element_op,
k_id,
Kt);
}
#else
ignore = karg;
@@ -79,11 +87,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
__shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// Full K needed for matrix B
const index_t Kt = karg.K;
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K);
const index_t k_id = blockIdx.z * num_k_per_block;
GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_b_grid,
karg.p_ds_grid,
karg.p_c_grid,
p_shared,
@@ -91,7 +105,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
karg,
karg.a_element_op,
karg.b_element_op,
karg.c_element_op);
karg.c_element_op,
k_id,
Kt);
}
#else
ignore = karg;
@@ -691,16 +707,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_k_split_offset = k_id * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = k_id * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
// KPack * NLane * KLane * K0 * N0
b_k_split_offset = k_id * karg.KRead * NLane;
}
if(k_id < karg.KBatch - 1)
{
karg.K = karg.KRead;
@@ -712,7 +718,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
}
index_t a_k_split_offset;
index_t b_k_split_offset;
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -1163,7 +1168,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const index_t k_id,
const index_t Kt)
{
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
Run<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
@@ -1176,7 +1183,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
block_2_ctile_map,
k_id,
Kt);
}
template <typename Block2CTileMap,
@@ -1192,11 +1201,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map,
const index_t k_id,
const index_t Kt)
{
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
@@ -1293,7 +1304,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
k_id,
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment
@@ -1597,7 +1608,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
CElementwiseOperation c_element_op,
const index_t k_id,
const index_t Kt)
{
const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4};
Run_2Lds<Block2CTileMapDefault, HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
@@ -1611,7 +1624,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
a_element_op,
b_element_op,
c_element_op,
block_2_ctile_map);
block_2_ctile_map,
k_id,
Kt);
}
template <typename Block2CTileMap,
@@ -1628,11 +1643,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
const Block2CTileMap& block_2_ctile_map)
const Block2CTileMap& block_2_ctile_map,
const index_t k_id,
const index_t Kt)
{
ignore = b_element_op;
index_t BN0Shuffled = CalculateBN0Shuffled(problem.N);
index_t BK0Shuffled = CalculateBK0Shuffled(problem.K);
index_t BK0Shuffled = CalculateBK0Shuffled(Kt);
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
@@ -1731,7 +1748,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle
true>(b_grid_desc_bpreshuffled,
make_multi_index(n_block_data_idx_on_grid,
get_warp_local_1d_id() % NWave,
0,
k_id,
KPackPerGroup * (get_thread_local_1d_id() % WarpSize)));
// LDS allocation for A and B: be careful of alignment