Wmma support for gemm_ab_scale (#3314)

* Support gemm_ab_scale:

 - Add tests
 - Integrate the scaling implementation into the multiple-D GEMM
 - Generalize the existing b_scale implementation to ab_scale
 - Add instances
 - Generalize the implementation over ScaleBlockM, ScaleBlockN, and ScaleBlockK
 - Add support for all layouts supported by xdl
 - Fix split-K for xdl

* Fix copyright

* Wmma support for gemm_blockscale_wp (#3315)

* Support for preshuffle with AB scale

 - add support for B preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale
 - add support for AScaleLayout and BScaleLayout (these can differ
   from ALayout and BLayout, respectively)
 - add a Run method in the v1 pipeline to support preshuffle + scaling
 - add support for preshuffle GEMMs in the common invoker
 - Add split-K support

* Fix copyright header
Enrico Degregori
2025-12-11 09:06:20 +01:00
committed by GitHub
parent d66e5f667c
commit ce99cab605
51 changed files with 5144 additions and 552 deletions
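
Before the diffs, a minimal host-side sketch (not CK code) of the semantics gemm_ab_scale implements: every ScaleBlockM x ScaleBlockK tile of A and every ScaleBlockN x ScaleBlockK tile of B carries one scale factor applied to the corresponding operand elements before accumulation. All sizes and names below are illustrative.

#include <cstdio>
#include <vector>

void gemm_ab_scale_ref(int M, int N, int K,
                       int SBM, int SBN, int SBK,
                       const std::vector<float>& a,       // M x K, row-major
                       const std::vector<float>& a_scale, // ceil(M/SBM) x ceil(K/SBK), row-major
                       const std::vector<float>& b,       // K x N, row-major
                       const std::vector<float>& b_scale, // ceil(N/SBN) x ceil(K/SBK), row-major
                       std::vector<float>& e)             // M x N, row-major
{
    const int bk = (K + SBK - 1) / SBK; // scale blocks along K
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
                acc += (a[m * K + k] * a_scale[(m / SBM) * bk + k / SBK]) *
                       (b[k * N + n] * b_scale[(n / SBN) * bk + k / SBK]);
            e[m * N + n] = acc;
        }
}

int main()
{
    // Tiny assumed example: M = N = K = 2, one scale block covering everything.
    std::vector<float> a{1, 2, 3, 4}, b{1, 0, 0, 1}, sa{0.5f}, sb{2.f}, e(4);
    gemm_ab_scale_ref(2, 2, 2, 2, 2, 2, a, sa, b, sb, e);
    std::printf("%g %g %g %g\n", e[0], e[1], e[2], e[3]); // 1 2 3 4 (scales cancel)
    return 0;
}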

View File

@@ -388,11 +388,11 @@ struct ABTransferThreadTiles
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
return transform_tensor_descriptor(
BlockDesc{},
make_tuple(
make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow, Number<1>{})),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(make_unmerge_transform(make_tuple(
Number<ABK0 / KRow>{}, KRow, Number<KPack / KRow / ABK1>{})),
make_unmerge_transform(make_tuple(
Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
}
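
To make the descriptor change above concrete: an unmerge transform splits one flat dimension into several sub-dimensions, with the first listed length varying slowest. A small standalone sketch of that index arithmetic, using assumed example sizes (ABK0 = 16, KRow = 2, KPack = 16, ABK1 = 4, giving sub-lengths {8, 2, 2} for the first unmerge in the new code):

#include <array>
#include <cstdio>

// Decompose a flat index x over a d0*d1*d2 dimension into (x0, x1, x2) with
// x = (x0*d1 + x1)*d2 + x2, mirroring make_unmerge_transform(make_tuple(
// Number<ABK0 / KRow>{}, KRow, Number<KPack / KRow / ABK1>{})).
std::array<int, 3> unmerge3(int x, int d1, int d2)
{
    return {x / (d1 * d2), (x / d2) % d1, x % d2};
}

int main()
{
    for(int x = 0; x < 8 * 2 * 2; x += 5)
    {
        auto c = unmerge3(x, 2, 2);
        std::printf("x=%2d -> (%d, %d, %d)\n", x, c[0], c[1], c[2]);
    }
    return 0;
}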

View File

@@ -895,8 +895,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
c_thread_buf.Clear();
// Empty scale structs for the blockwise pipeline.
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
using ABScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = ABScale{};
auto b_scale_struct = ABScale{};
/*******************************************************************************/
//
@@ -919,6 +920,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
b0_block_buf,
b0_block_slice_copy_step,
acc0_thread_buf,
a_scale_struct,
b_scale_struct,
KBlockMainLoop,
1); // num_k_block_per_scale
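
The hunk above shows the pattern used throughout this PR: unscaled kernels reuse the scaled pipeline by passing an empty tag type, so the scaling code is compiled out. A sketch of that idiom with stand-in names (Empty, Scale, and pipeline_step are assumptions, not CK's types):

#include <cstdio>
#include <type_traits>

struct Empty {};
struct Scale { float value; };

template <typename AScale, typename BScale>
void pipeline_step(float& acc, float a, float b, const AScale& sa, const BScale& sb)
{
    if constexpr(std::is_same_v<AScale, Empty> && std::is_same_v<BScale, Empty>)
        acc += a * b; // no scaling code is instantiated on this path
    else
        acc += (a * sa.value) * (b * sb.value);
}

int main()
{
    float acc0 = 0.f, acc1 = 0.f;
    pipeline_step(acc0, 2.f, 3.f, Empty{}, Empty{});          // plain GEMM path
    pipeline_step(acc1, 2.f, 3.f, Scale{0.5f}, Scale{0.5f});  // scaled path
    std::printf("%g %g\n", acc0, acc1); // 6 1.5
    return 0;
}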

View File

@@ -618,8 +618,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
// Scale structs (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
using Scale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = Scale{};
auto b_scale_struct = Scale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
@@ -627,6 +628,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
@@ -646,6 +648,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_id);

View File

@@ -23,6 +23,7 @@ template <typename ALayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename AScaleType,
typename BsDataType,
typename BScaleType,
typename AccDataType,
@@ -34,6 +35,7 @@ template <typename ALayout,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN, // scale N
index_t ScaleBlockK, // scale K
index_t MPerBlock,
@@ -65,13 +67,16 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB,
bool PermuteA,
bool PermuteB,
bool IsBPreShuffled = false,
typename AScaleLayout = ALayout,
typename BScaleLayout = BLayout>
struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
@@ -123,7 +128,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeB,
PermuteA,
PermuteB,
false,
IsBPreShuffled,
true>
{
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
@@ -177,7 +182,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeB,
PermuteA,
PermuteB,
false,
IsBPreShuffled,
true>;
using Base::I0;
@@ -233,6 +238,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleA_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
@@ -242,6 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
StrideBs{StrideBs_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleA{StrideScaleA_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
@@ -251,7 +258,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
AK0{CalculateAK0Padded(K_, KBatch_)},
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
NBlock{CalculateNBlock(N_)},
Kt{K_}
{
}
@@ -275,11 +283,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
<< std::endl;
std::cout << "SE:" << StrideE << ", " << "SScaleA:" << StrideScaleA << ", "
<< "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
@@ -289,6 +297,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleA;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
@@ -299,6 +308,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t BK0;
index_t MBlock;
index_t NBlock;
index_t Kt;
};
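
The new Kt member exists because split-K later rewrites karg.K to the per-batch slice (karg.K = karg.KRead), while a preshuffled B tensor is addressed through a descriptor built over the full K. A minimal illustrative sketch (stand-in types, not CK code):

#include <cstdio>

struct ProblemSketch
{
    int K;  // shrunk to KRead for all but the last split-K batch
    int Kt; // original K, preserved for IsBPreShuffled descriptors
};

int descriptor_k(const ProblemSketch& p, bool is_b_preshuffled)
{
    // mirrors: const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
    return is_b_preshuffled ? p.Kt : p.K;
}

int main()
{
    ProblemSketch p{/*K after split*/ 128, /*Kt*/ 512};
    std::printf("%d %d\n", descriptor_k(p, false), descriptor_k(p, true)); // 128 512
    return 0;
}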
// Argument
@@ -315,7 +325,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleA_,
index_t StrideScaleB_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
@@ -329,12 +341,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
StrideBs_,
StrideDs_,
StrideE_,
StrideScaleA_,
StrideScaleB_,
k_batch_},
p_as_grid{},
p_bs_grid{},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_a_scale_grid{p_a_scale_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
@@ -379,6 +393,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AScaleType* p_a_scale_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
@@ -407,34 +422,52 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
if constexpr(IsBPreShuffled)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
else
{
if constexpr(!PermuteB)
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
static_for<0, NumBTensor, 1>{}([&](auto i) {
b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i];
});
}
else
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
const int k0_offset = karg.KRead * karg.N;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
if constexpr(!PermuteB)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
}
else
{
const int k0_offset = karg.KRead * karg.N;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
}
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, AScaleLayout>)
{
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, AScaleLayout>)
{
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BScaleLayout>)
{
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BScaleLayout>)
{
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
if(k_id < karg.KBatch - 1)
@@ -458,77 +491,225 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumATensor> a_k_split_offset;
std::array<index_t, NumBTensor> b_k_split_offset;
index_t scale_k_split_offset; // New member for scale matrix offset
index_t scale_a_k_split_offset; // A scale matrix offset
index_t scale_b_k_split_offset; // B scale matrix offset
index_t c_reduce_offset;
};
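
A worked sketch of the scale-offset arithmetic from the hunk above (illustrative only): each split-K batch advances KRead elements along K, which is KRead / ScaleBlockK scale blocks. When K is the fastest scale dimension (row-major AScaleLayout, column-major BScaleLayout) that block count is a plain element offset; otherwise it must be multiplied by the leading stride.

#include <cassert>
#include <cstdio>

int a_scale_split_offset(bool row_major, int k_id, int KRead, int ScaleBlockK, int StrideScaleA)
{
    const int k_blocks = k_id * (KRead / ScaleBlockK);
    return row_major ? k_blocks : k_blocks * StrideScaleA;
}

int main()
{
    // Assumed sizes: K = 512 split into KBatch = 4 -> KRead = 128, ScaleBlockK = 64.
    // Batch k_id = 2 starts 2 * (128 / 64) = 4 scale blocks into K.
    assert(a_scale_split_offset(true, 2, 128, 64, /*StrideScaleA*/ 8) == 4);
    assert(a_scale_split_offset(false, 2, 128, 64, /*StrideScaleA*/ 8) == 32);
    std::printf("ok\n");
    return 0;
}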
using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const BScaleType* p_b_scale_grid,
index_t block_n_id)
__device__ static constexpr auto
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
{
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, AScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
}
}
static constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma;
template <index_t NumberOfBuffers>
__device__ static auto
MakeAScale(const Problem& problem, const AScaleType* p_a_scale_grid, index_t block_m_id)
{
if constexpr(ck::is_same_v<AScaleType, void>)
{
using AScale = typename BlockwiseGemmPipe::Empty;
return AScale{};
}
else
{
#if defined(__gfx11__)
// TODO: remove this restriction
static_assert(ScaleBlockM >= MPerWmma,
"ScaleBlockM must be greater equal than MPerWmma");
#endif
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
selected_wmma.k_per_wmma,
"ScaleBlockK must be greater equal than KPerWmma");
static constexpr auto ScaleSliceSizeN = NRepeat;
static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
const auto a_scale_grid_desc_am_ak =
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
constexpr auto RegSizePerWmmaFull =
wmma.selected_wmma.num_acc_vgprs_per_wave * wmma.selected_wmma.acc_pack_number;
constexpr auto RegSizePerWmma =
math::integer_divide_ceil(RegSizePerWmmaFull, ScaleBlockM);
auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma;
auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread;
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
b_thread_offset_k / ScaleBlockK));
constexpr auto ScaleSliceSizeM =
ScaleBlockM < MPerWmma ? MRepeat * RegSizePerWmma
: math::integer_divide_ceil(MPerBlock, ScaleBlockM);
constexpr auto ScaleSliceStrideM =
math::integer_divide_ceil(MWaves * MPerWmma, ScaleBlockM);
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_scale_thread_desc.GetElementSpaceSize());
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
using BScale =
typename BlockwiseGemmPipe::template BScale<ScaleSliceSizeN,
ScaleSliceSizeK,
NWaves,
ScaleBlockK,
NumberOfBuffers,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;
auto a_thread_offset_m =
((get_thread_local_1d_id() % 32) / MPerWmma * RegSizePerWmma) /
math::integer_divide_ceil(ScaleBlockM, RegSizePerWmmaFull) +
(get_thread_local_1d_id() / 32) / NWaves * MPerWmma / ScaleBlockM;
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
constexpr index_t VectorDim =
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? 0 : 1;
constexpr index_t VectorSize =
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? RegSizePerWmma
: ScaleSliceSizeK;
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleType,
AScaleType,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_desc),
Sequence<RegSizePerWmma, ScaleSliceSizeK>,
Sequence<0, 1>,
VectorDim,
VectorSize,
1,
true>(
a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset_m, 0));
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleType>(
a_scale_thread_desc.GetElementSpaceSize());
using AScale =
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeM,
ScaleSliceStrideM,
ScaleSliceSizeK,
NumberOfBuffers,
RegSizePerWmma,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_copy),
decltype(a_scale_grid_buf),
decltype(a_scale_thread_buf),
decltype(a_scale_thread_desc)>;
return AScale{a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf};
}
}
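
On the ScaleSliceSizeM selection inside MakeAScale: my reading (an inference, not stated in the diff) is that when one scale block is narrower than a WMMA tile in M, different accumulator register rows of the same wave can fall under different scales, so the slice must cover MRepeat * RegSizePerWmma entries; otherwise one scale spans ceil(MPerBlock / ScaleBlockM) blocks. A constexpr sketch of that formula with assumed tile sizes:

#include <cstdio>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

constexpr int scale_slice_size_m(int ScaleBlockM, int MPerWmma, int MRepeat,
                                 int MPerBlock, int RegSizePerWmma)
{
    return ScaleBlockM < MPerWmma ? MRepeat * RegSizePerWmma
                                  : ceil_div(MPerBlock, ScaleBlockM);
}

int main()
{
    std::printf("%d\n", scale_slice_size_m(1, 16, 4, 128, 8));  // fine-grained: 32
    std::printf("%d\n", scale_slice_size_m(64, 16, 4, 128, 8)); // coarse: 2
    return 0;
}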
__device__ static constexpr auto
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
}
}
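
Both descriptor helpers build a naive 2-D view over scale blocks. Note the mirrored convention: for the A scale grid (BM x BK), RowMajor makes K the fastest dimension (strides (StrideScaleA, 1)), while for the B scale grid (BN x BK) it is ColumnMajor that makes K fastest, because B is an N x K view. A sketch with assumed packed strides:

#include <cstdio>

struct Desc2D { int stride0, stride1; };

int offset(Desc2D d, int i0, int i1) { return i0 * d.stride0 + i1 * d.stride1; }

int main()
{
    const int M = 256, K = 512, SBM = 128, SBK = 64;
    const int BM = (M + SBM - 1) / SBM; // 2
    const int BK = (K + SBK - 1) / SBK; // 8
    Desc2D a_row{/*StrideScaleA=*/BK, 1}; // RowMajor AScaleLayout, packed
    Desc2D a_col{1, /*StrideScaleA=*/BM}; // ColumnMajor AScaleLayout, packed
    std::printf("A scale (1,3): row-major %d, column-major %d\n",
                offset(a_row, 1, 3), offset(a_col, 1, 3)); // 11, 7
    return 0;
}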
template <index_t NumberOfBuffers>
__device__ static auto
MakeBScale(const Problem& problem, const BScaleType* p_b_scale_grid, index_t block_n_id)
{
if constexpr(ck::is_same_v<BScaleType, void>)
{
using BScale = typename BlockwiseGemmPipe::Empty;
return BScale{};
}
else
{
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
selected_wmma.k_per_wmma,
"ScaleBlockK must be greater equal than KPerWmma");
const auto b_scale_grid_desc_bn_ak =
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
constexpr auto ScaleSliceSizeN =
ScaleBlockN < NPerWmma ? NRepeat
: math::integer_divide_ceil(NPerBlock, ScaleBlockN);
constexpr auto ScaleSliceStrideN =
math::integer_divide_ceil(NWaves * NPerWmma, ScaleBlockN);
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
auto b_thread_offset_n = (get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma) /
ScaleBlockN;
constexpr index_t VectorDim =
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 0 : 1;
constexpr index_t VectorSize =
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 1 : ScaleSliceSizeK;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
VectorDim,
VectorSize,
1,
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, 0));
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleType>(
b_scale_thread_desc.GetElementSpaceSize());
using BScale =
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeN,
ScaleSliceStrideN,
ScaleSliceSizeK,
NumberOfBuffers,
1,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
}
}
__device__ static index_t GetKBlockPerScale()
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
if constexpr(ck::is_same_v<AScaleType, void> && ck::is_same_v<BScaleType, void>)
{
return 0;
}
else
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
}
}
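
GetKBlockPerScale is the number of KPerBlock main-loop steps that share one scale value along K (ceil division), with 0 now signalling "no scaling" when both scale types are void. A small sketch with assumed example values:

#include <cstdio>

constexpr int k_block_per_scale(int ScaleBlockK, int KPerBlock, bool any_scale)
{
    return any_scale ? (ScaleBlockK + KPerBlock - 1) / KPerBlock : 0;
}

int main()
{
    static_assert(k_block_per_scale(128, 32, true) == 4, "one scale spans 4 K blocks");
    static_assert(k_block_per_scale(32, 64, true) == 1, "scale changes every K block");
    static_assert(k_block_per_scale(128, 32, false) == 0, "scaling disabled");
    std::printf("ok\n");
    return 0;
}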
template <bool HasMainKBlockLoop,
@@ -539,18 +720,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const AScaleType* p_a_scale_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args)
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
@@ -562,12 +746,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -585,8 +763,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct
auto a_scale_struct = MakeAScale<1>(problem, p_a_scale_grid, block_m_id);
// BScale struct
auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id);
auto b_scale_struct = MakeBScale<1>(problem, p_b_scale_grid, block_n_id);
const index_t num_k_block_per_scale = GetKBlockPerScale();
@@ -594,6 +775,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
@@ -613,8 +795,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
epilogue_args,
k_id);
}
// NOTE: Wrapper function to have __global__ function in common
@@ -626,7 +810,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args)
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
@@ -644,18 +829,40 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
splitk_batch_offset.b_k_split_offset[i];
});
const AScaleType* p_a_scale_grid_ptr;
if constexpr(ck::is_same_v<AScaleType, void>)
{
p_a_scale_grid_ptr = karg.p_a_scale_grid;
}
else
{
p_a_scale_grid_ptr = karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset;
}
const BScaleType* p_b_scale_grid_ptr;
if constexpr(ck::is_same_v<BScaleType, void>)
{
p_b_scale_grid_ptr = karg.p_b_scale_grid;
}
else
{
p_b_scale_grid_ptr = karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset;
}
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_a_scale_grid_ptr,
p_b_scale_grid_ptr,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
epilogue_args,
k_id);
}
};

View File

@@ -69,6 +69,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
}
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
const index_t k_id = blockIdx.z * num_k_per_block;
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
#endif
}
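
The gfx11 branch above compiles the kernel body out when the output memory operation is an fp16/bf16 atomic add, since gfx11 lacks those packed atomic instructions. A plain-C++ sketch of that compile-time guard, with stand-in types (MemOp and half_t here are assumptions, not CK's definitions):

#include <cstdio>
#include <type_traits>

enum class MemOp { Set, AtomicAdd };
struct half_t {};

template <MemOp Op, typename E>
void kernel_body_if_supported()
{
    // Mirrors the if constexpr guard: skip instantiation only for the
    // unsupported AtomicAdd + half/bf16 combination.
    if constexpr(!(Op == MemOp::AtomicAdd && std::is_same_v<E, half_t>))
        std::printf("body instantiated\n");
    else
        std::printf("skipped on gfx11\n");
}

int main()
{
    kernel_body_if_supported<MemOp::Set, half_t>();       // body instantiated
    kernel_body_if_supported<MemOp::AtomicAdd, half_t>(); // skipped on gfx11
    return 0;
}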
template <typename ALayout,
typename BLayout,
typename DsLayout,
@@ -162,7 +204,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
static constexpr index_t KInner = IsBPreShuffled ? KInnerB : ck::math::min(KInnerA, KInnerB);
static constexpr index_t KPack =
KInner *
@@ -966,6 +1008,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
typename BGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AScaleStruct,
typename BScaleStruct,
typename EpilogueArgument,
bool HasMainKBlockLoop,
@@ -988,6 +1031,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
@@ -1072,6 +1116,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
a_scale_struct,
b_scale_struct,
num_k_block_main_loop,
num_k_block_per_scale);

View File

@@ -43,13 +43,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -405,31 +407,33 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
}
}
__host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
__host__ __device__ static constexpr auto
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
{
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1));
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM));
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
}
}
__host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
__host__ __device__ static constexpr auto
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1));
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN));
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
}
}
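
With the stride now a runtime argument instead of the hardcoded BK/BM, callers with packed scale tensors recover the old behaviour by passing the packed stride explicitly. A sketch of that default (illustrative helper, not part of the diff):

#include <cstdio>

int packed_stride_scale_a(bool row_major, int M, int K, int ScaleBlockM, int ScaleBlockK)
{
    // Row-major A scale was hardcoded to BK = ceil(K/ScaleBlockK);
    // column-major was hardcoded to BM = ceil(M/ScaleBlockM).
    return row_major ? (K + ScaleBlockK - 1) / ScaleBlockK
                     : (M + ScaleBlockM - 1) / ScaleBlockM;
}

int main()
{
    std::printf("row-major packed StrideScaleA = %d\n",
                packed_stride_scale_a(true, 256, 512, 128, 64)); // 8
    return 0;
}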
@@ -548,6 +552,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t StrideScaleA_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
N{N_},
@@ -556,6 +562,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideC{StrideC_},
StrideScaleA{StrideScaleA_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
@@ -585,7 +593,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideC;
index_t StrideScaleA;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
index_t NPadded;
@@ -611,13 +620,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t StrideScaleA_,
index_t StrideScaleB_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
: Problem{M_,
N_,
K_,
StrideA_,
StrideB_,
StrideDs_,
StrideC_,
StrideScaleA_,
StrideScaleB_,
k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{},
@@ -673,6 +693,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
b_k_split_offset = blockIdx.z * karg.KRead;
}
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
scale_a_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
scale_a_k_split_offset =
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_b_k_split_offset =
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_b_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
@@ -685,6 +727,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t scale_a_k_split_offset; // A scale matrix offset
index_t scale_b_k_split_offset; // B scale matrix offset
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -1221,8 +1265,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
const auto a_scale_grid_desc_am_ak =
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
const auto b_scale_grid_desc_bn_ak =
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(