Add MoE & FP8 Blockscale WP Kernels for GFX950 (#2297)

* [fix] align v3 gufusion pipeline

* fix device kernel selection.

* Add .co direct asm support by CK_USE_ASM_MOE_STAGE2_BLOCKSCALE

* experimental optimization for scale load in blkscale gemm

* Add asm for no-loop v3_128x128x128

* fix bugs

* tune fp8 example

* Update v1_128x128x128 to 2x2 instead of 4x1

* wip

* add warmup to asm launch

* wip2

* 16x16 function merged to moe

* temp save, a performant version.

* wip3

* Update .co binary to 16x16

* 16x16x128 correct; 64x64x128 failed

* update

* use mem_op::set when topk=1

* add mx fp8 b_preshuffle support, function not yet tested.

* Spilt the fp4 target. Fix the known bugs. 128x128x128 sanity checked; remove prints

* some fixes

* fix update

* remove some unnecessary hacky; enable 256x256x256 tilesize

* update for function debug

* Add pipeline v3. Have some runtime issue and register spill

* Fix pipe v3 correctness issue

* remove unnecessary hacky

* clang format

* fix a bug

* fix the bug, functional test passed

* tempsave; buggy at passed 4 e8m0 to scaled mfma

* added fp4_bpreshuffle example, build failures

* fixed some bugs

* implement shuffled scale mxfp4gemm, blocker: opsel not effect

* hotfix

* fix bugs, build passed

* (M, N, K)=(128, 128, 128) function failed.

* temp save for gemm1. Function not ready

* fix compile error. Gemm2 pass. Gemm1 WIP

* fix bug for a lds read

* update moe

* Compile pass. Gemm1 function WIP

* update moe

* fix fp8; fix even/odd

* tempsave

* update moe

* Revert "update"

This reverts commit 960b2bce1c.

* Revert "use mem_op::set when topk=1"

This reverts commit def952a178.

* Add v3 128x128x128_4x4_16x16.co for gfx950

* temp cmake flag suppression  for aiter test

* add code for mxfp4 gemm, blockscale not supported yet

* gemm1 up-only pass. GU WIP

* function pass with inline asm hacky

* revert unexpected file change

* updated and build passed

* update CE elementOP

* added code for debug

* Gemm1 GUFusion function pass. Perf WIP

* Fix fp8/bf8; remove duplicated code

* disable the scheduler in v3; bring it back when compiler feature ready.

* update moe v1 pipeline

* Add gemm1 v1 32x128x128

* remove schedule barrier

* updated

* Fix fp8/bf8 B-row

* mfma using asm, device result correct, host result need to check

* gemm1 v3 64x128x128 debug

* fix cpu ref

* a/b thread_desc stride fix

* Use random scale for init1

* 16x16x128 input size blockscale function passed

* fix blockscale gemm bug

* tempsave. Almost all instances passed.

* v1 fix for mi350.

* temp save

* debug save

* update debug

* fix the bug, 128x128x256 tile function passed

* v3

* rename moe block selector and pipeline

* Add gemm1 v1

* Add gemm1 v1 to selector

* added mx moe block v3 support, function passed

* compile error fix

* Improve the pipeline

* Pack e8m0 as int32_t

* v1 compile pass. Function not ready

* debug synchronize issue over different GPU/ROCm

* minor fix

* Add profiler filter

* Add f4 ckProfiler

* Fix example compile error

* Add f4 profiler examples

* tempsave

* v1 function pass.

* v3 function pass

* align file and function name

* mx_moe_fp4 ready for aiter with clang-format.

* modify the way we represent fp4

* generalize the pipeline scheduling.

* init moe mx f4 scale shuffle

* Cmakelist diable compiler-bound flags

* mx_fp4 default parameter change

* Moe blockscale gemm1&gemm2 asm support for aiter. Suppression cmkae flag til new compler.

* update code

* tempsave; modify the way we represent fp4

* generalize the pipeline scheduling.

* Add gemm1 gfx942 .co support

* updated code, build passed.

* Update gemm2 asm with latest compiler flag

* Fix mx f4 ckProfiler

* Fix blockwise gemm mx v1

* lds conflict free + buffer load lds

* Add gemm2 v3 64x128x128

* fix a, b scale loading bugs, a, b scale loading now correctly

* Add gemm2 v3 64x128x128

* commit with debug info

* fix fp4 profiler

* Add mx fp4 pileline v1 instances

* Fix v2 topk_weight cal. Add silu asm.

* v2 tok_weight WIP

* init mx fp4 B no preshuffle version

* tempsave. compile pass, function wrong

* enable fp4 moe no weigth preshuffle, function pass

* update the TFlops calculation in the example

* Add gemm2 64x128x128 asm. Fix BF16 ref.

* fix 2 typos in fp4_preshuffle

* Better kernel selection in device classes

* correct preShuffleBuffer

we should used packed k to do shuffle.

* lds conflict free + buffer load lds

* optimize offset math in dma

* Fix fp4 ckProfiler

* Fix MX MFMA tests

* fix f4 pipeline issues

* gemm1 func pass

* update mx moe gemm1_bns tile size to 64x128x256

* update mx moe gemm1 gemm2 TF and BW calculation

* fix typo

* temp save

* Fix example_gemm_mx build

* rename the block pipeline

* correct a typo in tail

* Add rotating to mx examples

* fix the correctness issue

* Fix v1; use M padding

* Add NT flag to B/BScale buffer

* Merge gemm_mx_common.hpp

* temp save, 4.4~4.5

* Fix 'Merge gemm_mx_common.hpp'

* refactor the pipeline

* Pad the M for scale buffer unconditionaly

* update MX moe GEMM1 hotloopscheduling

* change the gemm1 tile from 64x128x128 to 128x64x128

* Unconditional Ascale padding

* Pad shuffled a scale only

* pad ascale

* add vmcnt guard for async copy

* Profiler add f4 wp

* Merge preshuffle device

* Add more fp4 wp instances

* Fix do_weight in gemm1. Fix cshuffle_datatype. Clang-format

* Clang-format after 2 merges

* Remove rocm6.3 workaround flags and macro

* Fix fp8 config

* Fix bf8 config

* flag and barrier fix for copmiler branch MainOpSelV3

* Add fp8 profiler instances

* Remove debug infos; Enable flags for blockscale f8

* No asm ver. for merging moe blocksale fp8 into mainline

* update the flag name for f8blockscale

* recover example

* fix performance bug of bpreshuffle f8 gemm

* clang format, remove  single rate mfma restriction for f8

* remove single rate mfma restriction for f8 blockscale gemm

* Fix moe blockscale gemm1 barrier 0x800 for new compiler

* add pipeline v1 for MOE Gemm2

* Use v1 pipeline for example_moe_gemm2_xdl_mx_fp4_bns

* Fix OOB; add MB96 instances

* remove unnecessary files

* fix the cmake issue

* Enable splitk for mxfp4; clang format;

* Generate random tensor values with multiple threads

* Use packed_size_v for A/BPackedSize

* Fix warning

* Fix target_compile_options for disabled target on gfx942

* fix moe pki4 on gfx950

* doc the kGroup definition

* Fix ThreadwiseTensorSliceTransfer_v4::Run (Fuse scale)

* Refactor thread_copy_lds_direct_load; fix gfx942 direct lds load example; fix f16_pki4 example

* Fix unknown compiler flag

* fix two failed examples.

* fix some failure tile size in gfx950 universal gemm. fix test_gemm_fp16

* workaround fix for test_gemm_f32; * We have very limited support for lds direct load if input matrix is not K major

* fix test_gemm_splitk;

* Fix compile for mx_mfma_op

* add mfma selection logic for multipled_v3

* Clean up

* Fix device gemm mx link error

* improve the global atomic pattern

* Revert unnecessary copyright updates

* restore minimum_occupancy logic

* Avoid data race in moe gemm2 ref

* Build fp8 gemm_multiply_multiply and moe only on gfx94/95

* update the instance in device_mx_gemm

* Resolve comments

* Copyright 2025

* Remove unused code

* fix library linking issue

---------

Co-authored-by: OscarXu <huaiguxu@amd.com>
Co-authored-by: lalala-sh <Jiaxing.Wen@amd.com>
Co-authored-by: mtgu0705 <mtgu@amd.com>
Co-authored-by: aska-0096 <haocwang@amd.com>
Co-authored-by: Your Name <you@example.com>
Co-authored-by: valarLip <340077269@qq.com>
Co-authored-by: feifei14119 <feiw@amd.com>
Co-authored-by: Lin, Qun <qlin@amd.com>
Co-authored-by: Andriy Roshchenko <andriy.roshchenko@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
Yi DING
2025-06-12 09:25:59 +08:00
committed by GitHub
parent 8c1ed6f4c1
commit 37554c31e8
85 changed files with 32508 additions and 431 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#ifndef __HIPCC_RTC__
@@ -149,6 +149,52 @@ struct DeviceGemmMultipleDSplitKBPreShuffle : public BaseOperator
#endif
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockSize,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceMoEGemmMXBPreShuffle : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef CK_CODE_GEN_RTC
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
ck::index_t M,
ck::index_t N,
ck::index_t K,
ck::index_t StrideA,
ck::index_t StrideAScale,
ck::index_t StrideB,
ck::index_t StrideBScale,
std::array<ck::index_t, NumDTensor> StrideDs,
ck::index_t StrideE,
ck::index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
#endif
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -60,6 +60,49 @@ struct DeviceGemmMultipleD_ABScale : public BaseOperator
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,507 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct DeviceGemmMultiD_BlockScale_Xdl_CShuffle_V3_BPreshuffle
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm
using GridwiseGemm = GridwiseGemmMultiD_blockscale_xdl_cshuffle_v3_b_preshuffle<
ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
using Argument = typename GridwiseGemm::Argument;
int GetPreShuffleParameters() override { return NPerXDL; }
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
// unconditional 2 to remove agpr usage
constexpr index_t minimum_occupancy = 2;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle_2lds<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3_multi_d_blockscale_b_preshuffle<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_xdl_supported())
{
return false;
}
// if(ScaleBlockM % MPerBlock != 0 || ScaleBlockN % NPerBlock != 0 || ScaleBlockK !=
// KPerBlock)
// {
// return false;
// }
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
// Padding to release this restriction
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
const index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
const index_t M,
const index_t N,
const index_t K,
const index_t StrideA,
const index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"}};
// clang-format off
str << "DeviceGemmXdlUniversal"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,584 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <hip/hip_runtime.h>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t ActivationOP = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = false,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ComputeTypeA,
typename LDSTypeB = ComputeTypeB>
struct DeviceMoeGemmBlockScale
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm = GridwiseMoeGemmBlockScale<
ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
BDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ActivationOP,
NSwizzle,
IsInputGemm,
MulRoutedWeight,
IndexType,
ComputeTypeA,
ComputeTypeB,
LDSTypeA,
LDSTypeB>;
using Argument = typename GridwiseGemm::Argument;
static constexpr index_t APackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
return 2;
else
return 1;
}();
static constexpr index_t BPackedSize = []() {
if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
return 2;
else
return 1;
}();
int GetPreShuffleParameters() override { return NPerXDL; }
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
std::array<std::size_t, NumDTensor> DsSize;
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / BPackedSize;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
4 * (1 + GridwiseGemm::NWave);
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize /
4 * (2) * (IsInputGemm ? 2 : 1);
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
BlockSize / 4 * (IsInputGemm ? 2 : 1);
constexpr auto estimated_reg_total =
estimated_reg_a + estimated_reg_b + estimated_reg_c;
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else
{
throw std::runtime_error("todo: only v1 & v2 support now");
}
}
#if 1
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_gemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
#endif
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
// only impl kbatch 1 now
if(arg.KBatch > 1)
{
return false;
}
if(!ck::is_xdl_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_sorted_token_ids,
const void* p_sorted_expert_ids,
const void* p_max_token_id,
const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t NumTokens,
index_t TopK,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
static_cast<const index_t*>(p_sorted_expert_ids),
static_cast<const index_t*>(p_max_token_id),
static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
NumTokens,
TopK,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
// index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(nullptr,
nullptr,
nullptr,
static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
static_cast<CDataType*>(p_c),
M, // randoms set, no use
0,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideC,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1, // KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"}};
// clang-format off
str << "DeviceMoeGEmm"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,571 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t ScaleBlockSize,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t ActivationOP = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = BDataType>
struct DeviceMoeGemmMX : public DeviceMoEGemmMXBPreShuffle<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockSize,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm =
GridwiseMoeGemmMX<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ActivationOP,
NSwizzle,
IsInputGemm,
MulRoutedWeight,
IndexType,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
int GetPreShuffleParameters() override { return NPerXDL; }
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
std::array<std::size_t, NumDTensor> DsSize;
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / BPackedSize;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) /
APackedSize / BlockSize / 4 *
(1 + GridwiseGemm::NWave);
constexpr auto estimated_reg_b = NPerBlock * KPerBlock * sizeof(BDataType) /
BPackedSize / BlockSize / 4 * (2) *
(IsInputGemm ? 2 : 1);
constexpr auto estimated_reg_c = MPerBlock * NPerBlock * sizeof(GemmAccDataType) /
BlockSize / 4 * (IsInputGemm ? 2 : 1);
constexpr auto estimated_reg_total =
estimated_reg_a + estimated_reg_b + estimated_reg_c;
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else
{
throw std::runtime_error("todo: only v1 & v3 support now");
}
}
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm_2lds<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
// only impl kbatch 1 now
if(arg.KBatch > 1)
{
return false;
}
if(!ck::is_xdl_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_sorted_token_ids,
const void* p_sorted_expert_ids,
const void* p_max_token_id,
const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t NumTokens,
index_t TopK,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideScaleA,
index_t StrideB,
index_t StrideScaleB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
static_cast<const index_t*>(p_sorted_expert_ids),
static_cast<const index_t*>(p_max_token_id),
static_cast<const ADataType*>(p_a),
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
p_ds,
static_cast<CDataType*>(p_c),
NumTokens,
TopK,
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideScaleA,
index_t StrideB,
index_t StrideScaleB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(nullptr,
nullptr,
nullptr,
static_cast<const ADataType*>(p_a),
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
p_ds,
static_cast<CDataType*>(p_c),
M, // randoms set, no use
0,
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceMoeGEmmMx"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -0,0 +1,540 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_moe_mx_gemm_bns.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename GemmAccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t ScaleBlockSize,
index_t BlockSize,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerXDL,
index_t NPerXDL,
index_t MXdlPerWave,
index_t NXdlPerWave,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
index_t ActivationOP = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool MulRoutedWeight = true,
typename IndexType = index_t,
typename ComputeTypeA = ADataType,
typename ComputeTypeB = BDataType>
struct DeviceMoeGemmMXBNS : public DeviceMoEGemmMXBPreShuffle<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockSize,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using GridwiseGemm =
GridwiseMoeGemmMXBNS<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
GemmAccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
ScaleBlockSize,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerXDL,
NPerXDL,
MXdlPerWave,
NXdlPerWave,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMXdlPerWavePerShuffle,
CShuffleNXdlPerWavePerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CDEShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ActivationOP,
NSwizzle,
IsInputGemm,
MulRoutedWeight,
IndexType,
ComputeTypeA,
ComputeTypeB>;
using Argument = typename GridwiseGemm::Argument;
static constexpr index_t APackedSize = packed_size_v<ADataType>;
static constexpr index_t BPackedSize = packed_size_v<BDataType>;
int GetPreShuffleParameters() override { return NPerXDL; }
// Invoker
struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
std::array<std::size_t, NumDTensor> DsSize;
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer =
a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
auto size_b_buffer =
b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
DsSize[i] = ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiD<Argument, DsDataType> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer, DsSize);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
// TODO: Check if this is the right algorithm for minimum_occupancy
constexpr index_t minimum_occupancy =
BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
? 2
: 1
: 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Full>;
RunKernel(kernel);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
true,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
else
{
throw std::runtime_error("todo: only v1 & v3 support now");
}
}
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Full>;
RunKernel(kernel);
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Odd>;
RunKernel(kernel);
}
else
{
const auto kernel = kernel_moe_mxgemm<GridwiseGemm,
false,
MemoryDataOp,
minimum_occupancy,
TailNumber::Even>;
RunKernel(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
// only impl kbatch 1 now
if(arg.KBatch > 1)
{
return false;
}
if(!ck::is_xdl_supported())
{
return false;
}
if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
{
return false;
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_sorted_token_ids,
const void* p_sorted_expert_ids,
const void* p_max_token_id,
const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t NumTokens,
index_t TopK,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideScaleA,
index_t StrideB,
index_t StrideScaleB,
std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{static_cast<const index_t*>(p_sorted_token_ids),
static_cast<const index_t*>(p_sorted_expert_ids),
static_cast<const index_t*>(p_max_token_id),
static_cast<const ADataType*>(p_a),
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
p_ds,
static_cast<CDataType*>(p_c),
NumTokens,
TopK,
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_a_scale,
const void* p_b,
const void* p_b_scale,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideScaleA,
index_t StrideB,
index_t StrideScaleB,
std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(nullptr,
nullptr,
nullptr,
static_cast<const ADataType*>(p_a),
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BDataType*>(p_b),
static_cast<const BScaleDataType*>(p_b_scale),
p_ds,
static_cast<CDataType*>(p_c),
M, // randoms set, no use
0,
M,
N,
K,
StrideA,
StrideScaleA,
StrideB,
StrideScaleB,
StrideDs,
StrideC,
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceMoeGEmmMx"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerXDL<<"x"<<NPerXDL << ", "
<< "WaveMap: "
<< MXdlPerWave<<"x" << NXdlPerWave<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck