mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Add MoE & FP8 Blockscale WP Kernels for GFX950 (#2297)
* [fix] align v3 gufusion pipeline * fix device kernel selection. * Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE * experimental optimization for scale load in blkscale gemm * Add asm for no-loop v3_128x128x128 * fix bugs * tune fp8 example * Update v1_128x128x128 to 2x2 instead of 4x1 * wip * add warmup to asm launch * wip2 * 16x16 function merged to moe * temp save, a performant version. * wip3 * Update .co binary to 16x16 * 16x16x128 correct; 64x64x128 failed * update * use mem_op::set when topk=1 * add mx fp8 b_preshuffle support, function not yet tested. * Spilt the fp4 target. Fix the known bugs. 128x128x128 sanity checked; remove prints * some fixes * fix update * remove some unnecessary hacky; enable 256x256x256 tilesize * update for function debug * Add pipeline v3. Have some runtime issue and register spill * Fix pipe v3 correctness issue * remove unnecessary hacky * clang format * fix a bug * fix the bug, functional test passed * tempsave; buggy at passed 4 e8m0 to scaled mfma * added fp4_bpreshuffle example, build failures * fixed some bugs * implement shuffled scale mxfp4gemm, blocker: opsel not effect * hotfix * fix bugs, build passed * (M, N, K)=(128, 128, 128) function failed. * temp save for gemm1. Function not ready * fix compile error. Gemm2 pass. Gemm1 WIP * fix bug for a lds read * update moe * Compile pass. Gemm1 function WIP * update moe * fix fp8; fix even/odd * tempsave * update moe * Revert "update" This reverts commit960b2bce1c. * Revert "use mem_op::set when topk=1" This reverts commitdef952a178. * Add v3 128x128x128_4x4_16x16.co for gfx950 * temp cmake flag suppression for aiter test * add code for mxfp4 gemm, blockscale not supported yet * gemm1 up-only pass. GU WIP * function pass with inline asm hacky * revert unexpected file change * updated and build passed * update CE elementOP * added code for debug * Gemm1 GUFusion function pass. Perf WIP * Fix fp8/bf8; remove duplicated code * disable the scheduler in v3; bring it back when compiler feature ready. * update moe v1 pipeline * Add gemm1 v1 32x128x128 * remove schedule barrier * updated * Fix fp8/bf8 B-row * mfma using asm, device result correct, host result need to check * gemm1 v3 64x128x128 debug * fix cpu ref * a/b thread_desc stride fix * Use random scale for init1 * 16x16x128 input size blockscale function passed * fix blockscale gemm bug * tempsave. Almost all instances passed. * v1 fix for mi350. * temp save * debug save * update debug * fix the bug, 128x128x256 tile function passed * v3 * rename moe block selector and pipeline * Add gemm1 v1 * Add gemm1 v1 to selector * added mx moe block v3 support, function passed * compile error fix * Improve the pipeline * Pack e8m0 as int32_t * v1 compile pass. Function not ready * debug synchronize issue over different GPU/ROCm * minor fix * Add profiler filter * Add f4 ckProfiler * Fix example compile error * Add f4 profiler examples * tempsave * v1 function pass. * v3 function pass * align file and function name * mx_moe_fp4 ready for aiter with clang-format. * modify the way we represent fp4 * generalize the pipeline scheduling. * init moe mx f4 scale shuffle * Cmakelist diable compiler-bound flags * mx_fp4 default parameter change * Moe blockscale gemm1&gemm2 asm support for aiter. Suppression cmkae flag til new compler. * update code * tempsave; modify the way we represent fp4 * generalize the pipeline scheduling. * Add gemm1 gfx942 .co support * updated code, build passed. * Update gemm2 asm with latest compiler flag * Fix mx f4 ckProfiler * Fix blockwise gemm mx v1 * lds conflict free + buffer load lds * Add gemm2 v3 64x128x128 * fix a, b scale loading bugs, a, b scale loading now correctly * Add gemm2 v3 64x128x128 * commit with debug info * fix fp4 profiler * Add mx fp4 pileline v1 instances * Fix v2 topk_weight cal. Add silu asm. * v2 tok_weight WIP * init mx fp4 B no preshuffle version * tempsave. compile pass, function wrong * enable fp4 moe no weigth preshuffle, function pass * update the TFlops calculation in the example * Add gemm2 64x128x128 asm. Fix BF16 ref. * fix 2 typos in fp4_preshuffle * Better kernel selection in device classes * correct preShuffleBuffer we should used packed k to do shuffle. * lds conflict free + buffer load lds * optimize offset math in dma * Fix fp4 ckProfiler * Fix MX MFMA tests * fix f4 pipeline issues * gemm1 func pass * update mx moe gemm1_bns tile size to 64x128x256 * update mx moe gemm1 gemm2 TF and BW calculation * fix typo * temp save * Fix example_gemm_mx build * rename the block pipeline * correct a typo in tail * Add rotating to mx examples * fix the correctness issue * Fix v1; use M padding * Add NT flag to B/BScale buffer * Merge gemm_mx_common.hpp * temp save, 4.4~4.5 * Fix 'Merge gemm_mx_common.hpp' * refactor the pipeline * Pad the M for scale buffer unconditionaly * update MX moe GEMM1 hotloopscheduling * change the gemm1 tile from 64x128x128 to 128x64x128 * Unconditional Ascale padding * Pad shuffled a scale only * pad ascale * add vmcnt guard for async copy * Profiler add f4 wp * Merge preshuffle device * Add more fp4 wp instances * Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format * Clang-format after 2 merges * Remove rocm6.3 workaround flags and macro * Fix fp8 config * Fix bf8 config * flag and barrier fix for copmiler branch MainOpSelV3 * Add fp8 profiler instances * Remove debug infos; Enable flags for blockscale f8 * No asm ver. for merging moe blocksale fp8 into mainline * update the flag name for f8blockscale * recover example * fix performance bug of bpreshuffle f8 gemm * clang format, remove single rate mfma restriction for f8 * remove single rate mfma restriction for f8 blockscale gemm * Fix moe blockscale gemm1 barrier 0x800 for new compiler * add pipeline v1 for MOE Gemm2 * Use v1 pipeline for example_moe_gemm2_xdl_mx_fp4_bns * Fix OOB; add MB96 instances * remove unnecessary files * fix the cmake issue * Enable splitk for mxfp4; clang format; * Generate random tensor values with multiple threads * Use packed_size_v for A/BPackedSize * Fix warning * Fix target_compile_options for disabled target on gfx942 * fix moe pki4 on gfx950 * doc the kGroup definition * Fix ThreadwiseTensorSliceTransfer_v4::Run (Fuse scale) * Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example * Fix unknown compiler flag * fix two failed examples. * fix some failure tile size in gfx950 universal gemm. fix test_gemm_fp16 * workaround fix for test_gemm_f32; * We have very limited support for lds direct load if input matrix is not K major * fix test_gemm_splitk; * Fix compile for mx_mfma_op * add mfma selection logic for multipled_v3 * Clean up * Fix device gemm mx link error * improve the global atomic pattern * Revert unnecessary copyright updates * restore minimum_occupancy logic * Avoid data race in moe gemm2 ref * Build fp8 gemm_multiply_multiply and moe only on gfx94/95 * update the instance in device_mx_gemm * Resolve comments * Copyright 2025 * Remove unused code * fix library linking issue --------- Co-authored-by: OscarXu <huaiguxu@amd.com> Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com> Co-authored-by: mtgu0705 <mtgu@amd.com> Co-authored-by: aska-0096 <haocwang@amd.com> Co-authored-by: Your Name <you@example.com> Co-authored-by: valarLip <340077269@qq.com> Co-authored-by: feifei14119 <feiw@amd.com> Co-authored-by: Lin, Qun <qlin@amd.com> Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com> Co-authored-by: joye <joye@amd.com> Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#ifndef __HIPCC_RTC__
|
||||
@@ -149,6 +149,52 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
|
||||
#endif
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockSize,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceMoEGemmMXBPreShuffle : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
#ifndef CK_CODE_GEN_RTC
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
ck::index_t M,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t StrideA,
|
||||
ck::index_t StrideAScale,
|
||||
ck::index_t StrideB,
|
||||
ck::index_t StrideBScale,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
ck::index_t StrideE,
|
||||
ck::index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -60,6 +60,49 @@ struct DeviceGemmMultipleD_ABScale : public BaseOperator
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename AScaleType,
|
||||
typename BDataType,
|
||||
typename BScaleType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation>
|
||||
struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_e,
|
||||
const ck::index_t M,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t StrideA,
|
||||
const ck::index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const ck::index_t StrideE,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CDEElementwiseOperation cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
|
||||
virtual int GetPreShuffleParameters() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
@@ -0,0 +1,507 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
typename LDSTypeB = ComputeTypeB>
|
||||
struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto Run = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
// unconditional 2 to remove agpr usage
|
||||
constexpr index_t minimum_occupancy = 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds<
|
||||
GridwiseGemm,
|
||||
true,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
Run(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
|
||||
GridwiseGemm,
|
||||
false,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
Run(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK !=
|
||||
// KPerBlock)
|
||||
// {
|
||||
// return false;
|
||||
// }
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Padding to release this restriction
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
const index_t M,
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t StrideA,
|
||||
const index_t StrideB,
|
||||
const std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
const index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdlUniversal"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,584 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t BlockSize,
|
||||
index_t ScaleBlockM,
|
||||
index_t ScaleBlockN,
|
||||
index_t ScaleBlockK,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = false,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = CDataType,
|
||||
typename ComputeTypeB = ComputeTypeA,
|
||||
typename LDSTypeA = ComputeTypeA,
|
||||
typename LDSTypeB = ComputeTypeB>
|
||||
struct DeviceMoeGemmBlockScale
|
||||
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm = GridwiseMoeGemmBlockScale<
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
BlockSize,
|
||||
ScaleBlockM,
|
||||
ScaleBlockN,
|
||||
ScaleBlockK,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
LDSTypeA,
|
||||
LDSTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
static constexpr index_t BPackedSize = []() {
|
||||
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
|
||||
return 2;
|
||||
else
|
||||
return 1;
|
||||
}();
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / BPackedSize;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
|
||||
4 * (1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
|
||||
4 * (2) * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
|
||||
BlockSize / 4 * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v2 support now");
|
||||
}
|
||||
}
|
||||
#if 1
|
||||
else
|
||||
{
|
||||
// Tail number always 1
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// only impl kbatch 1 now
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_sorted_token_ids,
|
||||
const void* p_sorted_expert_ids,
|
||||
const void* p_max_token_id,
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t NumTokens,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
|
||||
static_cast<const index_t*>(p_sorted_expert_ids),
|
||||
static_cast<const index_t*>(p_max_token_id),
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
NumTokens,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_b,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
const void* p_a_scale,
|
||||
const void* p_b_scale,
|
||||
// index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M, // randoms set, no use
|
||||
0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
1, // KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmm"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,571 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t ScaleBlockSize,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = ADataType,
|
||||
typename ComputeTypeB = BDataType>
|
||||
struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm =
|
||||
GridwiseMoeGemmMX<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
ScaleBlockSize,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
|
||||
sizeof(ADataType) / APackedSize;
|
||||
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
|
||||
sizeof(BDataType) / BPackedSize;
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
|
||||
APackedSize / BlockSize / 4 *
|
||||
(1 + GridwiseGemm::NWave);
|
||||
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) /
|
||||
BPackedSize / BlockSize / 4 * (2) *
|
||||
(IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
|
||||
BlockSize / 4 * (IsInputGemm ? 2 : 1);
|
||||
constexpr auto estimated_reg_total =
|
||||
estimated_reg_a + estimated_reg_b + estimated_reg_c;
|
||||
|
||||
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
|
||||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v3 support now");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// only impl kbatch 1 now
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_sorted_token_ids,
|
||||
const void* p_sorted_expert_ids,
|
||||
const void* p_max_token_id,
|
||||
const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t NumTokens,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
|
||||
static_cast<const index_t*>(p_sorted_expert_ids),
|
||||
static_cast<const index_t*>(p_max_token_id),
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
NumTokens,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M, // randoms set, no use
|
||||
0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmmMx"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,540 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/flush_cache.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename AScaleDataType,
|
||||
typename BDataType,
|
||||
typename BScaleDataType,
|
||||
typename DsDataType,
|
||||
typename CDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
GemmSpecialization GemmSpec,
|
||||
index_t ScaleBlockSize,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t AK1,
|
||||
index_t BK1,
|
||||
index_t MPerXDL,
|
||||
index_t NPerXDL,
|
||||
index_t MXdlPerWave,
|
||||
index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
|
||||
index_t ActivationOP = 0,
|
||||
bool NSwizzle = false,
|
||||
bool IsInputGemm = true,
|
||||
bool MulRoutedWeight = true,
|
||||
typename IndexType = index_t,
|
||||
typename ComputeTypeA = ADataType,
|
||||
typename ComputeTypeB = BDataType>
|
||||
struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
ScaleBlockSize,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation>
|
||||
{
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
using GridwiseGemm =
|
||||
GridwiseMoeGemmMXBNS<ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ADataType,
|
||||
AScaleDataType,
|
||||
BDataType,
|
||||
BScaleDataType,
|
||||
GemmAccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
CDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
GemmSpec,
|
||||
ScaleBlockSize,
|
||||
BlockSize,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
AK1,
|
||||
BK1,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVectors,
|
||||
BlkGemmPipeSched,
|
||||
BlkGemmPipelineVer,
|
||||
ActivationOP,
|
||||
NSwizzle,
|
||||
IsInputGemm,
|
||||
MulRoutedWeight,
|
||||
IndexType,
|
||||
ComputeTypeA,
|
||||
ComputeTypeB>;
|
||||
|
||||
using Argument = typename GridwiseGemm::Argument;
|
||||
|
||||
static constexpr index_t APackedSize = packed_size_v<ADataType>;
|
||||
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
|
||||
|
||||
int GetPreShuffleParameters() override { return NPerXDL; }
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg))
|
||||
{
|
||||
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
|
||||
}
|
||||
|
||||
index_t gdx, gdy, gdz;
|
||||
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
index_t k_grain = arg.KBatch * KPerBlock;
|
||||
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
|
||||
|
||||
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
|
||||
|
||||
const auto RunKernel = [&](const auto& kernel) {
|
||||
if(stream_config.flush_cache)
|
||||
{
|
||||
|
||||
std::array<std::size_t, NumDTensor> DsSize;
|
||||
|
||||
Argument arg_ = arg;
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
|
||||
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
|
||||
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
|
||||
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
|
||||
|
||||
auto size_a_buffer =
|
||||
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
|
||||
auto size_b_buffer =
|
||||
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
|
||||
|
||||
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
|
||||
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
|
||||
});
|
||||
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
|
||||
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck::utility::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(arg_.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
|
||||
0,
|
||||
arg_.M * arg_.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
};
|
||||
|
||||
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
|
||||
stream_config,
|
||||
run_flush_cache,
|
||||
kernel,
|
||||
dim3(gdx, gdy, gdz),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg_);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(arg.KBatch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
|
||||
0,
|
||||
arg.M * arg.N * sizeof(CDataType),
|
||||
stream_config.stream_id_));
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: Check if this is the right algorithm for minimum_occupancy
|
||||
constexpr index_t minimum_occupancy =
|
||||
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
|
||||
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
|
||||
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
|
||||
? 2
|
||||
: 1
|
||||
: 2;
|
||||
|
||||
constexpr auto MemoryDataOp =
|
||||
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
// Tail number always full
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
true,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("todo: only v1 & v3 support now");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Full>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
{
|
||||
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Odd>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
|
||||
false,
|
||||
MemoryDataOp,
|
||||
minimum_occupancy,
|
||||
TailNumber::Even>;
|
||||
RunKernel(kernel);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// only impl kbatch 1 now
|
||||
if(arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
|
||||
GemmSpec == GemmSpecialization::NKPadding ||
|
||||
GemmSpec == GemmSpecialization::MNKPadding ||
|
||||
GemmSpec == GemmSpecialization::KPadding))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return GridwiseGemm::CheckValidity(arg);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const void* p_sorted_token_ids,
|
||||
const void* p_sorted_expert_ids,
|
||||
const void* p_max_token_id,
|
||||
const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t NumTokens,
|
||||
index_t TopK,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
std::array<index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op)
|
||||
{
|
||||
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
|
||||
static_cast<const index_t*>(p_sorted_expert_ids),
|
||||
static_cast<const index_t*>(p_max_token_id),
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
NumTokens,
|
||||
TopK,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
|
||||
const void* p_a_scale,
|
||||
const void* p_b,
|
||||
const void* p_b_scale,
|
||||
std::array<const void*, NumDTensor> p_ds,
|
||||
void* p_c,
|
||||
index_t M,
|
||||
index_t N,
|
||||
index_t K,
|
||||
index_t StrideA,
|
||||
index_t StrideScaleA,
|
||||
index_t StrideB,
|
||||
index_t StrideScaleB,
|
||||
std::array<ck::index_t, NumDTensor> StrideDs,
|
||||
index_t StrideC,
|
||||
index_t KBatch,
|
||||
AElementwiseOperation a_element_op,
|
||||
BElementwiseOperation b_element_op,
|
||||
CElementwiseOperation c_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
static_cast<const ADataType*>(p_a),
|
||||
static_cast<const AScaleDataType*>(p_a_scale),
|
||||
static_cast<const BDataType*>(p_b),
|
||||
static_cast<const BScaleDataType*>(p_b_scale),
|
||||
p_ds,
|
||||
static_cast<CDataType*>(p_c),
|
||||
M, // randoms set, no use
|
||||
0,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
StrideA,
|
||||
StrideScaleA,
|
||||
StrideB,
|
||||
StrideScaleB,
|
||||
StrideDs,
|
||||
StrideC,
|
||||
KBatch,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
|
||||
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
|
||||
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
|
||||
|
||||
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
|
||||
{BlockGemmPipelineVersion::v1, "v1"},
|
||||
{BlockGemmPipelineVersion::v2, "v2"},
|
||||
{BlockGemmPipelineVersion::v3, "v3"},
|
||||
{BlockGemmPipelineVersion::v4, "v4"},
|
||||
{BlockGemmPipelineVersion::v5, "v5"}};
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceMoeGEmmMx"
|
||||
<< "<"
|
||||
<< getGemmSpecializationString(GemmSpec) << ", "
|
||||
<< std::string(ALayout::name)[0]
|
||||
<< std::string(BLayout::name)[0]
|
||||
<< std::string(CLayout::name)[0]
|
||||
<< ">"
|
||||
<< " BlkSize: "
|
||||
<< BlockSize << ", "
|
||||
<< "BlkTile: "
|
||||
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
|
||||
<< "WaveTile: "
|
||||
<< MPerXDL<<"x"<<NPerXDL << ", "
|
||||
<< "WaveMap: "
|
||||
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
|
||||
<< "VmemReadVec: "
|
||||
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
|
||||
<< "BlkGemmPipelineScheduler: "
|
||||
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
|
||||
<< "BlkGemmPipelineVersion: "
|
||||
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
|
||||
<< "BlkGemmPipelinePrefetchStages: "
|
||||
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user