Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-04-19 22:39:03 +00:00
Wmma support for gemm_ab_scale (#3314)
* Support gemm_ab_scale:
  - Add tests
  - Integrate scaling implementation in multiple D
  - Generalize existing b_scale for ab_scale
  - Add instances
  - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK
  - Add support for all layouts supported by xdl
  - Fix splitk xdl
* Fix copyright
* Wmma support for gemm_blockscale_wp (#3315)
* Support for preshuffle with ab scale:
  - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale
  - add support for AScaleLayout and BScaleLayout (can be different from ALayout and BLayout, respectively)
  - add Run method in v1 pipeline to support preshuffle + scaling
  - add support for preshuffle gemms in common invoker
  - Add splitk support
* Fix copyright header
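For orientation, this is what an AB block-scale GEMM computes. The sketch below is a hypothetical host-side reference, not code from this commit: every ScaleBlockM x ScaleBlockK tile of A and every ScaleBlockN x ScaleBlockK tile of B carries one scale factor that multiplies the product before accumulation.

    // Hypothetical reference (assumed semantics, row-major tensors):
    // E[m][n] = sum_k A[m][k] * a_scale[m/SBM][k/SBK] * B[k][n] * b_scale[n/SBN][k/SBK]
    #include <cstddef>
    #include <vector>

    void ref_ab_scale_gemm(const std::vector<float>& a,        // M x K
                           const std::vector<float>& b,        // K x N
                           const std::vector<float>& a_scale,  // ceil(M/SBM) x ceil(K/SBK)
                           const std::vector<float>& b_scale,  // ceil(N/SBN) x ceil(K/SBK)
                           std::vector<float>& e,              // M x N
                           std::size_t M, std::size_t N, std::size_t K,
                           std::size_t SBM, std::size_t SBN, std::size_t SBK)
    {
        const std::size_t scale_k = (K + SBK - 1) / SBK; // scale columns along K
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                {
                    const float sa = a_scale[(m / SBM) * scale_k + k / SBK];
                    const float sb = b_scale[(n / SBN) * scale_k + k / SBK];
                    acc += a[m * K + k] * sa * b[k * N + n] * sb;
                }
                e[m * N + n] = acc;
            }
    }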
@@ -388,11 +388,11 @@ struct ABTransferThreadTiles
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
return transform_tensor_descriptor(
BlockDesc{},
make_tuple(
make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow, Number<1>{})),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(make_unmerge_transform(make_tuple(
Number<ABK0 / KRow>{}, KRow, Number<KPack / KRow / ABK1>{})),
make_unmerge_transform(make_tuple(
Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
}
@@ -895,8 +895,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
c_thread_buf.Clear();

// Empty BScale struct for the blockwise pipeline.
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
using ABScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = ABScale{};
auto b_scale_struct = ABScale{};

/*******************************************************************************/
//

@@ -919,6 +920,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
b0_block_buf,
b0_block_slice_copy_step,
acc0_thread_buf,
a_scale_struct,
b_scale_struct,
KBlockMainLoop,
1); // num_k_block_per_scale
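The pattern in the hunk above recurs throughout this commit: paths that do not scale pass an empty tag object where scaled paths pass a real scale struct, so one blockwise pipeline serves both. A minimal standalone sketch of the idea, with all names hypothetical:

    // Sketch only: scaled vs. unscaled accumulation selected at compile time.
    #include <type_traits>

    struct Empty {}; // stands in for BlockwiseGemmPipe::Empty

    template <typename ScaleStruct>
    void accumulate(float& acc, float partial, const ScaleStruct& scale)
    {
        if constexpr(std::is_same_v<ScaleStruct, Empty>)
        {
            acc += partial; // unscaled GEMM: the tag compiles away entirely
        }
        else
        {
            acc += partial * scale.value; // block-scaled GEMM path
        }
    }

The Empty object occupies no storage and the branch is resolved at compile time, so the unscaled instantiation pays nothing for the added generality.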
@@ -618,8 +618,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);

// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
using Scale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = Scale{};
auto b_scale_struct = Scale{};

const index_t num_k_block_per_scale = GetKBlockPerScale();

@@ -627,6 +628,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,

@@ -646,6 +648,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_id);
@@ -23,6 +23,7 @@ template <typename ALayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename AScaleType,
typename BsDataType,
typename BScaleType,
typename AccDataType,

@@ -34,6 +35,7 @@ template <typename ALayout,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN, // scale N
index_t ScaleBlockK, // scale K
index_t MPerBlock,

@@ -65,13 +67,16 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB,
bool PermuteA,
bool PermuteB,
bool IsBPreShuffled = false,
typename AScaleLayout = ALayout,
typename BScaleLayout = BLayout>
struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,

@@ -123,7 +128,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeB,
PermuteA,
PermuteB,
false,
IsBPreShuffled,
true>
{
using Base = GridwiseGemm_wmma_cshuffle_v3_base<

@@ -177,7 +182,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeB,
PermuteA,
PermuteB,
false,
IsBPreShuffled,
true>;

using Base::I0;

@@ -233,6 +238,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleA_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},

@@ -242,6 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
StrideBs{StrideBs_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleA{StrideScaleA_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},

@@ -251,7 +258,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
AK0{CalculateAK0Padded(K_, KBatch_)},
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
NBlock{CalculateNBlock(N_)},
Kt{K_}
{
}
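A note on the new Kt member initialized from K_ above: under split-K, SplitKBatchOffset later overwrites karg.K with the per-batch KRead, while the preshuffled-B descriptor still needs the K of the whole problem (see K_b = IsBPreShuffled ? problem.Kt : problem.K further down). A hypothetical sketch of that distinction, not repository code:

    // Sketch: why the original K is kept alongside the (possibly rewritten) K.
    struct ProblemSketch { int K; int Kt; };

    int descriptor_k(const ProblemSketch& p, bool is_b_preshuffled)
    {
        return is_b_preshuffled ? p.Kt : p.K; // mirrors K_b in Run()
    }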
@@ -275,11 +283,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
<< std::endl;
std::cout << "SE:" << StrideE << ", " << "SScaleA:" << StrideScaleA << ", "
<< "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
}

index_t M;

@@ -289,6 +297,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleA;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;

@@ -299,6 +308,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t BK0;
index_t MBlock;
index_t NBlock;
index_t Kt;
};

// Argument
@@ -315,7 +325,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleA_,
index_t StrideScaleB_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,

@@ -329,12 +341,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
StrideBs_,
StrideDs_,
StrideE_,
StrideScaleA_,
StrideScaleB_,
k_batch_},
p_as_grid{},
p_bs_grid{},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_a_scale_grid{p_a_scale_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},

@@ -379,6 +393,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
DsGridPointer p_ds_grid;
EDataType* p_e_grid;

const AScaleType* p_a_scale_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;

@@ -407,34 +422,52 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
}

if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
if constexpr(IsBPreShuffled)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
else
{
if constexpr(!PermuteB)
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
static_for<0, NumBTensor, 1>{}([&](auto i) {
b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i];
});
}
else
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
const int k0_offset = karg.KRead * karg.N;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
if constexpr(!PermuteB)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
}
else
{
const int k0_offset = karg.KRead * karg.N;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
}
}
}

// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, AScaleLayout>)
{
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, AScaleLayout>)
{
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
}

// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BScaleLayout>)
{
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BScaleLayout>)
{
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}

if(k_id < karg.KBatch - 1)
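To make the scale-offset arithmetic above concrete: each split-K batch advances KRead elements along K, i.e. KRead / ScaleBlockK positions in the scale grid; the layout whose K dimension is contiguous needs no stride factor, while the other multiplies by the scale stride. A small host-side illustration with hypothetical numbers, not from the commit:

    #include <cstdio>

    int main()
    {
        const int KRead = 256, ScaleBlockK = 64; // hypothetical split-K slice / granularity
        const int StrideScale = 1024;            // hypothetical scale stride
        for(int k_id = 0; k_id < 4; ++k_id)
        {
            const int contiguous_k_off = k_id * (KRead / ScaleBlockK);               // 0, 4, 8, 12
            const int strided_k_off    = k_id * (KRead / ScaleBlockK) * StrideScale; // 0, 4096, ...
            std::printf("k_id=%d contiguous=%d strided=%d\n",
                        k_id, contiguous_k_off, strided_k_off);
        }
    }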
@@ -458,77 +491,225 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale

std::array<index_t, NumATensor> a_k_split_offset;
std::array<index_t, NumBTensor> b_k_split_offset;
index_t scale_k_split_offset; // New member for scale matrix offset
index_t scale_a_k_split_offset; // A scale matrix offset
index_t scale_b_k_split_offset; // B scale matrix offset
index_t c_reduce_offset;
};

using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;

// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;

template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const BScaleType* p_b_scale_grid,
index_t block_n_id)
__device__ static constexpr auto
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
{
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, AScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
}
}

static constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma;
template <index_t NumberOfBuffers>
__device__ static auto
MakeAScale(const Problem& problem, const AScaleType* p_a_scale_grid, index_t block_m_id)
{
if constexpr(ck::is_same_v<AScaleType, void>)
{
using AScale = typename BlockwiseGemmPipe::Empty;
return AScale{};
}
else
{
#if defined(__gfx11__)
// TODO: remove this restriction
static_assert(ScaleBlockM >= MPerWmma,
"ScaleBlockM must be greater equal than MPerWmma");
#endif
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
selected_wmma.k_per_wmma,
"ScaleBlockK must be greater equal than KPerWmma");

static constexpr auto ScaleSliceSizeN = NRepeat;
static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
const auto a_scale_grid_desc_am_ak =
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);

constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());

constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
constexpr auto RegSizePerWmmaFull =
wmma.selected_wmma.num_acc_vgprs_per_wave * wmma.selected_wmma.acc_pack_number;
constexpr auto RegSizePerWmma =
math::integer_divide_ceil(RegSizePerWmmaFull, ScaleBlockM);

auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma;
auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread;
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
b_thread_offset_k / ScaleBlockK));
constexpr auto ScaleSliceSizeM =
ScaleBlockM < MPerWmma ? MRepeat * RegSizePerWmma
: math::integer_divide_ceil(MPerBlock, ScaleBlockM);
constexpr auto ScaleSliceStrideM =
math::integer_divide_ceil(MWaves * MPerWmma, ScaleBlockM);
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);

auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_scale_thread_desc.GetElementSpaceSize());
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));

using BScale =
typename BlockwiseGemmPipe::template BScale<ScaleSliceSizeN,
ScaleSliceSizeK,
NWaves,
ScaleBlockK,
NumberOfBuffers,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;
auto a_thread_offset_m =
((get_thread_local_1d_id() % 32) / MPerWmma * RegSizePerWmma) /
math::integer_divide_ceil(ScaleBlockM, RegSizePerWmmaFull) +
(get_thread_local_1d_id() / 32) / NWaves * MPerWmma / ScaleBlockM;

return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
constexpr index_t VectorDim =
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? 0 : 1;
constexpr index_t VectorSize =
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? RegSizePerWmma
: ScaleSliceSizeK;

auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleType,
AScaleType,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_desc),
Sequence<RegSizePerWmma, ScaleSliceSizeK>,
Sequence<0, 1>,
VectorDim,
VectorSize,
1,
true>(
a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset_m, 0));

auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleType>(
a_scale_thread_desc.GetElementSpaceSize());

using AScale =
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeM,
ScaleSliceStrideM,
ScaleSliceSizeK,
NumberOfBuffers,
RegSizePerWmma,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_copy),
decltype(a_scale_grid_buf),
decltype(a_scale_thread_buf),
decltype(a_scale_thread_desc)>;

return AScale{a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf};
}
}

__device__ static constexpr auto
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
}
}

template <index_t NumberOfBuffers>
__device__ static auto
MakeBScale(const Problem& problem, const BScaleType* p_b_scale_grid, index_t block_n_id)
{
if constexpr(ck::is_same_v<BScaleType, void>)
{
using BScale = typename BlockwiseGemmPipe::Empty;
return BScale{};
}
else
{
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
selected_wmma.k_per_wmma,
"ScaleBlockK must be greater equal than KPerWmma");

const auto b_scale_grid_desc_bn_ak =
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);

const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());

constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

constexpr auto ScaleSliceSizeN =
ScaleBlockN < NPerWmma ? NRepeat
: math::integer_divide_ceil(NPerBlock, ScaleBlockN);
constexpr auto ScaleSliceStrideN =
math::integer_divide_ceil(NWaves * NPerWmma, ScaleBlockN);
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);

constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));

auto b_thread_offset_n = (get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma) /
ScaleBlockN;

constexpr index_t VectorDim =
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 0 : 1;
constexpr index_t VectorSize =
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 1 : ScaleSliceSizeK;

auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
VectorDim,
VectorSize,
1,
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, 0));

auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleType>(
b_scale_thread_desc.GetElementSpaceSize());

using BScale =
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeN,
ScaleSliceStrideN,
ScaleSliceSizeK,
NumberOfBuffers,
1,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;

return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
}
}

__device__ static index_t GetKBlockPerScale()
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
if constexpr(ck::is_same_v<AScaleType, void> && ck::is_same_v<BScaleType, void>)
{
return 0;
}
else
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
}
}
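GetKBlockPerScale is integer ceiling division: it counts how many KPerBlock main-loop tiles share one scale value along K (and 0 now signals that no scaling is active at all). A quick check with hypothetical tile sizes:

    #include <cstdio>

    int main()
    {
        const int ScaleBlockK = 128;
        const int tiles[] = {32, 64, 128, 256}; // hypothetical KPerBlock values
        for(int KPerBlock : tiles)
        {
            // ceil(ScaleBlockK / KPerBlock): k-blocks covered by one scale
            std::printf("KPerBlock=%d -> %d k-blocks per scale\n",
                        KPerBlock,
                        (ScaleBlockK + KPerBlock - 1) / KPerBlock); // 4, 2, 1, 1
        }
    }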
template <bool HasMainKBlockLoop,

@@ -539,18 +720,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const AScaleType* p_a_scale_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args)
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(

@@ -562,12 +746,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);

// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));

// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

@@ -585,8 +763,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

// AScale struct
auto a_scale_struct = MakeAScale<1>(problem, p_a_scale_grid, block_m_id);

// BScale struct
auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id);
auto b_scale_struct = MakeBScale<1>(problem, p_b_scale_grid, block_n_id);

const index_t num_k_block_per_scale = GetKBlockPerScale();

@@ -594,6 +775,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,

@@ -613,8 +795,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
epilogue_args,
k_id);
}

// NOTE: Wrapper function to have __global__ function in common

@@ -626,7 +810,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args)
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;

@@ -644,18 +829,40 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
splitk_batch_offset.b_k_split_offset[i];
});

const AScaleType* p_a_scale_grid_ptr;
if constexpr(ck::is_same_v<AScaleType, void>)
{
p_a_scale_grid_ptr = karg.p_a_scale_grid;
}
else
{
p_a_scale_grid_ptr = karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset;
}

const BScaleType* p_b_scale_grid_ptr;
if constexpr(ck::is_same_v<BScaleType, void>)
{
p_b_scale_grid_ptr = karg.p_b_scale_grid;
}
else
{
p_b_scale_grid_ptr = karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset;
}

Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_a_scale_grid_ptr,
p_b_scale_grid_ptr,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
epilogue_args,
k_id);
}
};
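One C++ detail behind the two if constexpr blocks above: when a scale type is void, the grid pointer is const void*, on which offset arithmetic is ill-formed (and no offset would be meaningful), so the pointer is forwarded untouched and the Empty scale path simply never dereferences it. A standalone sketch of the same guard, with hypothetical names:

    #include <type_traits>

    // Apply the split-K offset only when ScaleType is a real element type;
    // the void branch is discarded before the ill-formed arithmetic could
    // ever be instantiated.
    template <typename ScaleType>
    const ScaleType* shift_scale_ptr(const ScaleType* p, long offset)
    {
        if constexpr(std::is_same_v<ScaleType, void>)
            return p; // unscaled: pointer is unused downstream
        else
            return p + offset;
    }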
@@ -69,6 +69,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
}

template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
const index_t k_id = blockIdx.z * num_k_per_block;

auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};

GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);

#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
#endif
}
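Note how the new kernel derives k_id: blockIdx.z selects the split-K batch and each batch spans ceil(K / KPack) positions, with K as seen by the kernel at that point. A hypothetical numeric walk-through of that arithmetic, not repository code:

    #include <cstdio>

    int main()
    {
        const int K = 512, KPack = 16, grid_z = 4;           // hypothetical sizes
        const int num_k_per_block = (K + KPack - 1) / KPack; // ceil(512 / 16) = 32
        for(int z = 0; z < grid_z; ++z)
            std::printf("blockIdx.z=%d -> k_id=%d\n", z, z * num_k_per_block);
    }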
template <typename ALayout,
typename BLayout,
typename DsLayout,

@@ -162,7 +204,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base

static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);

static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
static constexpr index_t KInner = IsBPreShuffled ? KInnerB : ck::math::min(KInnerA, KInnerB);

static constexpr index_t KPack =
KInner *
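The KInner change above: when B arrives preshuffled, its K grouping is already fixed by the shuffle layout, so KInner follows the B side instead of the minimum of both operands. An illustration with hypothetical numbers:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int KPerWmmaBlk = 8;    // hypothetical WMMA K granularity
        const int AK1 = 16, BK1 = 32; // hypothetical A/B vector widths
        const int KInnerA = (AK1 + KPerWmmaBlk - 1) / KPerWmmaBlk; // 2
        const int KInnerB = (BK1 + KPerWmmaBlk - 1) / KPerWmmaBlk; // 4
        std::printf("regular:     KInner=%d\n", std::min(KInnerA, KInnerB)); // 2
        std::printf("preshuffled: KInner=%d\n", KInnerB);                    // 4
    }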
@@ -966,6 +1008,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
typename BGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AScaleStruct,
typename BScaleStruct,
typename EpilogueArgument,
bool HasMainKBlockLoop,

@@ -988,6 +1031,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)

@@ -1072,6 +1116,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
a_scale_struct,
b_scale_struct,
num_k_block_main_loop,
num_k_block_per_scale);
@@ -43,13 +43,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -405,31 +407,33 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
}
}

__host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
__host__ __device__ static constexpr auto
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
{
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1));
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM));
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
}
}

__host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
__host__ __device__ static constexpr auto
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1));
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN));
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
}
}
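The descriptor change above replaces the implicit packed pitch (BK for row-major A scales, BM for column-major) with a caller-supplied stride, so scale matrices may be padded or embedded in larger buffers. A hypothetical offset comparison:

    #include <cstdio>

    int main()
    {
        const int K = 2048, ScaleBlockK = 64;
        const int BK = (K + ScaleBlockK - 1) / ScaleBlockK; // 32 scale columns
        const int m_blk = 3, k_blk = 5;                     // hypothetical scale-tile index
        // Old behaviour: packed row-major scales imply a row pitch of exactly BK.
        std::printf("packed offset:  %d\n", m_blk * BK + k_blk);
        // New behaviour: the caller supplies the pitch, e.g. for a padded allocation.
        const int StrideScaleA = 40; // hypothetical padded pitch (>= BK)
        std::printf("strided offset: %d\n", m_blk * StrideScaleA + k_blk);
    }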
@@ -548,6 +552,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t StrideScaleA_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
N{N_},

@@ -556,6 +562,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideC{StrideC_},
StrideScaleA{StrideScaleA_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},

@@ -585,7 +593,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideC;

index_t StrideScaleA;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
index_t NPadded;

@@ -611,13 +620,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t StrideScaleA_,
index_t StrideScaleB_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
: Problem{M_,
N_,
K_,
StrideA_,
StrideB_,
StrideDs_,
StrideC_,
StrideScaleA_,
StrideScaleB_,
k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{},

@@ -673,6 +693,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
b_k_split_offset = blockIdx.z * karg.KRead;
}

// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
scale_a_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
scale_a_k_split_offset =
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
}

// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_b_k_split_offset =
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_b_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}

if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;

@@ -685,6 +727,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3

index_t a_k_split_offset;
index_t b_k_split_offset;
index_t scale_a_k_split_offset; // A scale matrix offset
index_t scale_b_k_split_offset; // B scale matrix offset
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()

@@ -1221,8 +1265,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);

const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
const auto a_scale_grid_desc_am_ak =
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
const auto b_scale_grid_desc_bn_ak =
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);

const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(