Dev/a8w4 and a8w8splitk (#3447)

* Ck moe bs splitk pr (#3440)

* splitk kick-off. Compilation fail

* splitk hack pass

* fix scale offset calc.

* clang-format for a8w8_moe_blk_gemm1 splitk change

* fix testcase error

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>

* Zan/moe a8w4 (#3441)

* update

* update

* update ck moe a8w4

* update

* update

* update

* compile pass

* update

* update

* python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready

* support new a8w4 kernel

* update

* update ck_tile

* re format

* update

* update

* fix conflict

* fix build

* update ck_tile moe

* fix clang format

* fix the problem

* fix accruacy issue

* fix

---------

Co-authored-by: oscar <huaiguxu@amd.com>
Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com>
Co-authored-by: Zzz9990 <zanzhang@amd.com>
Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
yadaish
2025-12-19 09:26:52 +08:00
committed by GitHub
parent ba897f8435
commit c0ee71d735
13 changed files with 2911 additions and 139 deletions

View File

@@ -74,6 +74,7 @@ template <typename ALayout,
index_t ActivationOP = 0,
bool NSwizzle = false,
bool IsInputGemm = true,
bool IsSplitK = false,
bool MulRoutedWeight = false,
typename IndexType = index_t,
typename ComputeTypeA = CDataType,
@@ -156,6 +157,7 @@ struct DeviceMoeGemmBlockScale
ActivationOP,
NSwizzle,
IsInputGemm,
IsSplitK,
MulRoutedWeight,
IndexType,
ComputeTypeA,
@@ -201,12 +203,12 @@ struct DeviceMoeGemmBlockScale
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
arg.M, arg.N * (IsInputGemm && IsSplitK ? 2 : 1), arg.K, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
index_t K_split = arg.KBatch == 1 ? arg.K : arg.KBatch * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto RunKernel = [&](const auto& kernel) {
@@ -249,11 +251,12 @@ struct DeviceMoeGemmBlockScale
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
// if(arg_.KBatch > 1)
// hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
// 0,
// arg_.M * arg_.N * sizeof(CDataType)
// * (IsInputGemm && IsSplitK ? 2 : 1),
// stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
@@ -267,11 +270,12 @@ struct DeviceMoeGemmBlockScale
}
else
{
if(arg.KBatch > 1)
hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
// if(arg.KBatch > 1)
// hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
// 0,
// arg.M * arg.N * sizeof(CDataType) *
// (IsInputGemm && IsSplitK ? 2 : 1),
// stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
@@ -289,8 +293,9 @@ struct DeviceMoeGemmBlockScale
constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
constexpr auto MemoryDataOp =
IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd;
constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK)
? InMemoryDataOperationEnum::Set
: InMemoryDataOperationEnum::AtomicAdd;
if(has_main_k_block_loop)
{
@@ -416,8 +421,8 @@ struct DeviceMoeGemmBlockScale
static bool IsSupportedArgument(const Argument& arg)
{
// only impl kbatch 1 now
if(arg.KBatch > 1)
// only impl kbatch 1 for fp32
if(arg.KBatch > 1 && !std::is_same_v<CDataType, float>)
{
return false;
}
@@ -441,6 +446,11 @@ struct DeviceMoeGemmBlockScale
{
return false;
}
if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0)
{
// Not support Kpadding with KBatch > 1
return false;
}
if(get_warp_size() == 64)
{