From e76ee195dfa42262eb50ef93df2cf8fe033cbf8d Mon Sep 17 00:00:00 2001 From: yadaish Date: Fri, 19 Dec 2025 09:26:52 +0800 Subject: [PATCH] Dev/a8w4 and a8w8splitk (#3447) * Ck moe bs splitk pr (#3440) * splitk kick-off. Compilation fail * splitk hack pass * fix scale offset calc. * clang-format for a8w8_moe_blk_gemm1 splitk change * fix testcase error --------- Co-authored-by: oscar Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> * Zan/moe a8w4 (#3441) * update * update * update ck moe a8w4 * update * update * update * compile pass * update * update * python3 op_tests/test_moe_2stage.py -t 16 -e 1 -k 1 -dim 256,256 ready * support new a8w4 kernel * update * update ck_tile * re format * update * update * fix conflict * fix build * update ck_tile moe * fix clang format * fix the problem * fix accruacy issue * fix --------- Co-authored-by: oscar Co-authored-by: huaiguxu <145733371+huaiguxu@users.noreply.github.com> Co-authored-by: Zzz9990 Co-authored-by: felix [ROCm/composable_kernel commit: c0ee71d73527cd8206038b86b6eeb4fcf955154e] --- .../65_gemm_multiply_multiply/CMakeLists.txt | 1 + .../moe_gemm1_xdl_fp8_blockscale.cpp | 4 +- .../moe_gemm1_xdl_fp8_blockscale_splitk.cpp | 539 ++++++++ .../moe_gemm2_xdl_fp8_blockscale.cpp | 4 +- ..._xdlops_moe_blockscale_b_preshuffle_v1.hpp | 3 + .../impl/device_moe_gemm_blockscale.hpp | 44 +- .../gpu/grid/gridwise_moe_gemm_blockscale.hpp | 125 +- .../core/tensor/tile_scatter_gather.hpp | 183 ++- .../ops/flatmm/kernel/moe_flatmm_kernel.hpp | 215 ++- ...ec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp | 1170 +++++++++++++++++ ...mm_pipeline_agmem_bgmem_creg_v1_policy.hpp | 511 ++++++- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 19 - .../reference_moe_gemm1_blockscale_splitk.hpp | 232 ++++ 13 files changed, 2911 insertions(+), 139 deletions(-) create mode 100644 example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp create mode 100644 library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 944a8f96bf..24a4106ae7 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -18,6 +18,7 @@ add_example_executable(example_moe_gemm1_xdl_fp8 moe_gemm1_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8 moe_gemm2_xdl_fp8.cpp) add_example_executable(example_moe_gemm2_xdl_fp8_blockscale moe_gemm2_xdl_fp8_blockscale.cpp) add_example_executable(example_moe_gemm1_xdl_fp8_blockscale moe_gemm1_xdl_fp8_blockscale.cpp) +add_example_executable(example_moe_gemm1_xdl_fp8_blockscale_splitk moe_gemm1_xdl_fp8_blockscale_splitk.cpp) list(APPEND gpu_list gfx942 gfx950 gfx1100 gfx1101 gfx1102 gfx1103 gfx1150 gfx1151 gfx1152 gfx1153 gfx1200 gfx1201 gfx11-generic gfx12-generic) set(target 0) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp index fdaef8ec3e..ecc3034bba 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale.cpp @@ -171,7 +171,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceM // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, - 
ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>; #else static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< Row, Col, DsLayout, ELayout, @@ -185,7 +185,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, true, false, MulRoutedWeight, int32_t, A0DataType>; #endif // clang-format on diff --git a/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp new file mode 100644 index 0000000000..ae707e74a2 --- /dev/null +++ b/example/65_gemm_multiply_multiply/moe_gemm1_xdl_fp8_blockscale_splitk.cpp @@ -0,0 +1,539 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +using ::ck::DeviceMem; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using F32 = float; +using I64 = int64_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using Bypass = ck::tensor_layout::BypassLayoutVerification; + +using A0DataType = F8; +using A1DataType = F32; +using B0DataType = F8; +using B1DataType = F32; +using EDataType = F32; +using AccDataType = F32; +using CShuffleDataType = EDataType; +using D2DataType = F32; +using DsDataType = ck::Tuple; + +using A0Layout = Row; +using B0Layout = Col; +using ELayout = Row; +using D0Layout = Row; +using D1Layout = Col; +using D2Layout = ELayout; +using DsLayout = ck::Tuple; + +struct MulABScaleExpertWeight +{ + template + __host__ __device__ constexpr void operator()(E& e, const C& c, const D2& d2) const; + // for real kernel use + template <> + __host__ __device__ constexpr void + operator()(EDataType& e, const EDataType& c, const float& d2) const + { + (void)d2; + e = ck::type_convert(c); + } +}; + +void preShuffleBuffer(const B0DataType* src, B0DataType* dst, int N, int K, int NXdl) +{ + int KPack = 16 / sizeof(B0DataType); + int NLane = NXdl; + int KLane = 64 / NLane; + + int K0 = K / (KLane * KPack); + // K -> K0 KLane 
KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(I64 n = 0; n < N; ++n) + { + for(I64 k = 0; k < K; ++k) + { + I64 n0 = n / NLane; + I64 n1 = n % NLane; + + I64 k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + I64 k1 = tempk / KPack; + I64 k2 = tempk % KPack; + + I64 outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * static_cast(K) + k]; + } + } +} +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = MulABScaleExpertWeight; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t Nswizzle = false; +static constexpr ck::index_t IsInputGemm = true; // splitk gemm1 goes to gemm2 pipeline. +static constexpr ck::index_t IsSplitK = true; // splitk gemm1 +static constexpr ck::index_t ActOP = 0; // 0: gelu_and_mul, 1: silu_and_mul +static constexpr bool MulRoutedWeight = false; // splitk gemm1 does not do routedWeight. + +#if 1 +static constexpr ck::index_t MPerBlock = 32; +static constexpr ck::index_t NPerBlock = 128; +static constexpr ck::index_t MNPerXDL = 16; +static constexpr ck::index_t MXDLPerWave = MPerBlock / (MNPerXDL * 1); +static constexpr ck::index_t NXDLPerWave = NPerBlock / (MNPerXDL * 4); +static constexpr ck::index_t CShuffleMXDLPerWave = MXDLPerWave; +static constexpr ck::index_t CShuffleNXDLPerWave = NXDLPerWave; +static constexpr ck::index_t BLOCKSIZE = 256; + +static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); +static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); +static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t EVec = 16 / sizeof(EDataType); +static constexpr ck::index_t D0Vec = 1; +static constexpr ck::index_t D1Vec = 1; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale + // clang-format off + < Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CDEElementOp, GemmSpec, + //threadnum, mblock, nblock, kblock + BLOCKSIZE, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, NPerBlock, KPerBlock, + // ak1, bk1 + AK1, BK1, + // mn_perxdl + MNPerXDL, MNPerXDL, + // mn_xdlperwave + MXDLPerWave, NXDLPerWave, + // a,b: loadtranfer cluster, cluster order, srcorder,VECDIM, srcpervec, dstpervec, lds_extra + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + CShuffleMXDLPerWave, CShuffleNXDLPerWave, S<1, 32, 1, 8>, S, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>; +#else + +static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< + Row, Col, DsLayout, ELayout, + A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, + AElementOp, 
BElementOp, CDEElementOp, GemmSpec, + 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, + MPerBlock, 128, 128, + 16, 16, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 4, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, ActOP, Nswizzle, IsInputGemm, IsSplitK, MulRoutedWeight, int32_t, A0DataType>; +#endif +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = true; +#if 1 + // GEMM shape + ck::index_t N = 4096; + ck::index_t K = 6144; + // ck::index_t N = 128; + // ck::index_t K = 512; + ck::index_t experts = 8; + ck::index_t topk = 2; + // ck::index_t sorted_tile_num = 515; + // ck::index_t valid_tile_num = 512; + // ck::index_t tokens = 208; + // ck::index_t sorted_tile_num = 15; + // ck::index_t valid_tile_num = 13; + // ck::index_t sorted_tile_num = 259; + // ck::index_t valid_tile_num = 256; + // ck::index_t tokens = 4096; + ck::index_t sorted_tile_num = 2; + ck::index_t valid_tile_num = 2; + ck::index_t tokens = 32; +#else + // deepseek + ck::index_t N = 2048; + ck::index_t K = 7168; + ck::index_t experts = 256; + ck::index_t topk = 8; + ck::index_t tokens = 4096; + ck::index_t sorted_tile_num = 261; + ck::index_t valid_tile_num = 256; +#endif + ck::index_t KBatch = 6; + if(argc == 1) + { + // use default case + } + else if(argc == 2) + { + KBatch = std::stoi(argv[1]); + } + else if(argc == 4) + { + // use default case + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 7) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + } + else if(argc == 9) + { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: N, K, tokens\n"); + exit(0); + } + + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; + if(tokens * topk > valid_size) + { + printf("err config, tokens * topk > valid_size\n"); + exit(-1); + } + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N * 2; + constexpr ck::index_t NumDTensor = DsDataType::Size(); + constexpr auto StrideDs = std::array{0}; + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_B = (N + Scale_Block_N - 1) / Scale_Block_N * 2; + + Tensor expert_ids(HostTensorDescriptor({sorted_tile_num}, {1})); + Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); + Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); + max_token_id.mData = {valid_size}; + // int eids[] = {0, 0, 1, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 3, 3, 3}; + for(int i = 0; i < sorted_tile_num; i++) + { + expert_ids.mData[i] = i / ck::math::integer_divide_ceil(valid_tile_num, experts); + } + + int token_per_tile = 
(tokens * topk + valid_tile_num - 1) / valid_tile_num; + int tokenid = 0; + + for(int i = 0; i < sorted_size; i++) + { + int tile_off = i % MPerBlock; + if(tile_off < token_per_tile && tokenid < tokens * topk) + { + sorted_token_ids.mData[i] = (tokenid % tokens) | ((tokenid / tokens) << 24); + tokenid++; + } + else + { + sorted_token_ids.mData[i] = tokens; + } + } + Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); + Tensor a1_t_k(HostTensorDescriptor( + {tokens, (K + Scale_Block_K - 1) / Scale_Block_K}, {Scale_Stride_AM, 1}, Row{})); + Tensor b0_e_n_k( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor b1_e_n_k( + HostTensorDescriptor({experts, + (K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N * 2}, + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN}, + Col{})); + Tensor b0_preshuffled( + HostTensorDescriptor({experts, K, N * 2}, {N * 2 * K, 1, K}, Col{})); + Tensor e_t_n_host_result( + HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{})); + Tensor e_t_n_device_result( + HostTensorDescriptor({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{})); + e_t_n_device_result.SetZero(); + std::cout << "a0_t_k: " << a0_t_k.mDesc << std::endl; + std::cout << "a1_t_k: " << a1_t_k.mDesc << std::endl; + std::cout << "b0_e_n_k: " << b0_e_n_k.mDesc << std::endl; + std::cout << "b1_e_n_k: " << b1_e_n_k.mDesc << std::endl; + std::cout << "e_t_n: " << e_t_n_host_result.mDesc << std::endl; + std::cout << "k_batch:" << KBatch << std::endl; + std::cout << "init_method:" << init_method << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-1.0, 1.0}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 2: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 4: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 6: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + break; + default: + a0_t_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * + sorted_token_ids.mDesc.GetElementSpaceSize()); + DeviceMem expert_ids_dev(sizeof(ck::index_t) 
* expert_ids.mDesc.GetElementSpaceSize()); + DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); + DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); + DeviceMem a1_device_buf(sizeof(A1DataType) * a1_t_k.mDesc.GetElementSpaceSize()); + DeviceMem b0_device_buf(sizeof(B0DataType) * b0_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); + + sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); + expert_ids_dev.ToDevice(expert_ids.mData.data()); + max_token_id_dev.ToDevice(max_token_id.mData.data()); + a0_device_buf.ToDevice(a0_t_k.mData.data()); + a1_device_buf.ToDevice(a1_t_k.mData.data()); + b1_device_buf.ToDevice(b1_e_n_k.mData.data()); + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + // do GEMM + auto device_op = DeviceOpInstance{}; + + int NPerXdl = device_op.GetPreShuffleParameters(); + + preShuffleBuffer( + b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * 2 * experts, K, NPerXdl); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(sorted_token_ids_dev.GetDeviceBuffer(), + expert_ids_dev.GetDeviceBuffer(), + max_token_id_dev.GetDeviceBuffer(), + a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{nullptr}, + e_device_buf.GetDeviceBuffer(), + tokens, + topk, + sorted_size, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + KBatch, + a_element_op, + b_element_op, + cde_element_op); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + if(time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * tokens * topk * N * 2 * K; + std::size_t num_btype = sizeof(A0DataType) * valid_tile_num * K + + sizeof(B0DataType) * K * N * 2 * experts + + sizeof(EDataType) * valid_tile_num * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s.\n" + << device_op.GetTypeString() << std::endl; + } + + if(do_verification) + { + // use atomic, so need to reinit outputs + e_device_buf.ToDevice(e_t_n_device_result.mData.data()); + invoker.Run(argument, StreamConfig{nullptr, false, 0, 0, 1}); + + Tensor a_t_k({tokens, K}); + Tensor b_e_n_k({experts, K, N * 2}); + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + Tensor c_t_k_n({tokens, topk, N * 2}, {topk * N * 2, N * 2, 1}, Row{}); + + // handle scale before ref. 
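+        // The two loops below dequantize A and B on the host before the reference
+        // GEMM: each FP8 value is multiplied by its per-block scale, i.e.
+        // a_t_k(t, k) = a0_t_k(t, k) * a1_t_k(t, k / Scale_Block_K) and
+        // b_e_n_k(e, k, n) = b0_e_n_k(e, k, n) * b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N),
+        // so the reference computes in plain FP32 while the device kernel applies
+        // the same scales inside its block-scaled split-K pipeline.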
+ for(int t = 0; t < tokens; ++t) + { + for(int k = 0; k < K; ++k) + { + a_t_k(t, k) = ck::type_convert(a0_t_k(t, k)) * a1_t_k(t, k / Scale_Block_K); + } + } + + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < K; ++k) + { + for(int n = 0; n < N * 2; ++n) + { + b_e_n_k(e, k, n) = ck::type_convert(b0_e_n_k(e, k, n)) * + b1_e_n_k(e, k / Scale_Block_K, n / Scale_Block_N); + } + } + } + using ReferenceGemmInstance = + ck::tensor_operation::host::ReferenceMoeGemm1BlockScaleSplitK; + auto ref_moe_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_moe_gemm.MakeInvoker(); + + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a_t_k, + b_e_n_k, + c_t_k_n, + PassThrough{}, + PassThrough{}, + PassThrough{}); + + ref_invoker.Run(ref_argument); + for(int m = 0; m < valid_size; ++m) + { + + const int fuse_t = sorted_token_ids.mData[m]; + const int t = fuse_t & 0xffffff; + const int topk_id = (fuse_t & 0xff000000) >> 24; + + if(t >= tokens) + { + continue; + } + for(int n = 0; n < 2 * N; ++n) + { + e_t_n_host_result(t, topk_id, n) = + ck::type_convert(c_t_k_n(t, topk_id, n)); + } + } + + e_device_buf.FromDevice(e_t_n_device_result.mData.data()); + + auto status = + ck::utils::check_err( + e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-1) + ? 0 + : 1; + if(status == 0) + { + printf("Validation Pass.\n"); + } + return status; + } + + return 0; +} diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 35c9d3d788..fb5e3b6456 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -165,7 +165,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, 2, 2, S<1, CShuffleMLane, 1, CShuffleNLane>, S, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>; #else static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< @@ -180,7 +180,7 @@ static constexpr ck::index_t MPerBlock = 64; using DeviceOpInstance = ck::tensor S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 2, 2, S<1, 32, 1, 8>, S<2, 1, 1, 1>, - ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, MulRoutedWeight, int32_t, A0DataType>; + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, 0, false, false, false, MulRoutedWeight, int32_t, A0DataType>; #endif // clang-format on diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp index 59265502e8..a76be40753 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp @@ -360,6 +360,7 @@ struct 
BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< }); }); + __builtin_amdgcn_sched_barrier(0); // Local prefill A1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); @@ -550,6 +551,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< }); }); + __builtin_amdgcn_sched_barrier(0); a_scale_thread_copy.Run(a_scale_grid_desc, a_scale_grid_buf, a_scale_thread_desc, @@ -677,6 +679,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< }); block_sync_lds(); + __builtin_amdgcn_sched_barrier(0); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp index feaace8919..df7179efe5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm_blockscale.hpp @@ -74,6 +74,7 @@ template 1) - hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); + // if(arg_.KBatch > 1) + // hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, + // 0, + // arg_.M * arg_.N * sizeof(CDataType) + // * (IsInputGemm && IsSplitK ? 2 : 1), + // stream_config.stream_id_)); }; ave_time = ck::utility::launch_and_time_kernel_with_preprocess( @@ -267,11 +270,12 @@ struct DeviceMoeGemmBlockScale } else { - if(arg.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); + // if(arg.KBatch > 1) + // hipGetErrorString(hipMemsetAsync(arg.p_c_grid, + // 0, + // arg.M * arg.N * sizeof(CDataType) * + // (IsInputGemm && IsSplitK ? 2 : 1), + // stream_config.stream_id_)); ave_time = launch_and_time_kernel( stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); @@ -289,8 +293,9 @@ struct DeviceMoeGemmBlockScale constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2; - constexpr auto MemoryDataOp = - IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; + constexpr auto MemoryDataOp = (IsInputGemm && !IsSplitK) + ? 
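+        // Each split-K k-batch contributes only a partial sum over its K slice, so the
+        // store to C must be AtomicAdd whenever split-K is active; only the non-split
+        // fused gate/up gemm1 path can overwrite C with a plain Set (this is also why
+        // the example re-zeroes the output buffer before the verification run).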
InMemoryDataOperationEnum::Set + : InMemoryDataOperationEnum::AtomicAdd; if(has_main_k_block_loop) { @@ -416,8 +421,8 @@ struct DeviceMoeGemmBlockScale static bool IsSupportedArgument(const Argument& arg) { - // only impl kbatch 1 now - if(arg.KBatch > 1) + // only impl kbatch 1 for fp32 + if(arg.KBatch > 1 && !std::is_same_v) { return false; } @@ -441,6 +446,11 @@ struct DeviceMoeGemmBlockScale { return false; } + if(arg.KBatch > 1 && arg.K % (KPerBlock * arg.KBatch) != 0) + { + // Not support Kpadding with KBatch > 1 + return false; + } if(get_warp_size() == 64) { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 0a565bf17e..c556dbec10 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -60,8 +60,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset, p_shared, karg, karg.a_element_op, @@ -101,8 +101,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.ascale_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.bscale_k_split_offset, p_shared, p_shared1, karg, @@ -167,6 +167,7 @@ template {}, Sequence<0>{}), make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - return a_grid_desc_ak0_m_ak1; } } @@ -741,35 +748,41 @@ struct GridwiseMoeGemmBlockScale { if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead / APackedSize; + a_k_split_offset = k_id * karg.KRead / APackedSize; + ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK); } else if constexpr(is_same_v) { - a_k_split_offset = k_id * karg.KRead * karg.StrideA; + a_k_split_offset = k_id * karg.KRead * karg.StrideA; + ascale_k_split_offset = math::integer_divide_floor(a_k_split_offset, ScaleBlockK); } if constexpr(is_same_v) { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; + b_k_split_offset = k_id * karg.KRead * karg.StrideB; + bscale_k_split_offset = math::integer_divide_floor(b_k_split_offset, ScaleBlockK); } else if constexpr(is_same_v) { // KPack * NLane * KLane * K0 * N0 - b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; + b_k_split_offset = k_id * karg.KRead * NLane / BPackedSize; + bscale_k_split_offset = k_id * karg.KRead / ScaleBlockK; } - if(k_id < karg.KBatch - 1) - { - karg.K = karg.KRead; - } - else - { - karg.K = karg.K - karg.KRead * (karg.KBatch - 1); - } + // if(k_id < karg.KBatch - 1) + // { + // karg.K = karg.KRead; + // } + // else + // { + // karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + // } } index_t a_k_split_offset; index_t b_k_split_offset; + index_t ascale_k_split_offset; + index_t bscale_k_split_offset; }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -912,8 +925,8 @@ struct GridwiseMoeGemmBlockScale } using BlockwiseGemmPipe = - remove_cvref_t())>; + IsInputGemm && !IsSplitK > ())>; __device__ static constexpr index_t GetSharedMemoryNumberOfByte() { @@ -1189,9 +1202,9 @@ struct GridwiseMoeGemmBlockScale BElementwiseOperation 
b_element_op, CElementwiseOperation c_element_op) { - ignore = b_element_op; - index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + ignore = b_element_op; + index_t BN0Shuffled = CalculateBN0Shuffled(problem.N * (IsInputGemm && IsSplitK ? 2 : 1)); + index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, @@ -1204,8 +1217,8 @@ struct GridwiseMoeGemmBlockScale const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, - problem.N, - problem.NPadded, + problem.N * (IsInputGemm && IsSplitK ? 2 : 1), + problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1), problem.StrideC); const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( @@ -1215,7 +1228,8 @@ struct GridwiseMoeGemmBlockScale math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + make_tuple(math::integer_divide_ceil(problem.N * (IsInputGemm && IsSplitK ? 2 : 1), + ScaleBlockN), math::integer_divide_ceil(problem.K, ScaleBlockK)), make_tuple(math::integer_divide_ceil(problem.K, ScaleBlockK), 1)); @@ -1371,9 +1385,10 @@ struct GridwiseMoeGemmBlockScale decltype(c_thread_buf) c_thread_buf_up; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - + problem.KBatch == 1 + ? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock + : problem.KBatch); constexpr index_t ScaleSliceSizeM = MXdlPerWave; constexpr index_t ScaleSliceSizeN = math::integer_divide_ceil(NPerBlock, ScaleBlockN); constexpr index_t ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); @@ -1447,7 +1462,7 @@ struct GridwiseMoeGemmBlockScale constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); - if constexpr(IsInputGemm) + if constexpr(IsInputGemm && !IsSplitK) { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( @@ -1606,7 +1621,7 @@ struct GridwiseMoeGemmBlockScale blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( make_tuple(m0, n0, n2 * N4 + n4)); constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion, elementwise + if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise { if constexpr(ActivationOperation == Activation::silu_and_mul) { @@ -1743,8 +1758,12 @@ struct GridwiseMoeGemmBlockScale using EDataType = CDataType; - const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( - problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + const auto ds_grid_desc_m_n = + MakeDsGridDescriptor_M_N(problem.M, + problem.MPadded, + problem.N * (IsInputGemm && IsSplitK ? 2 : 1), + problem.NPadded * (IsInputGemm && IsSplitK ? 
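+                                     // Split-K gemm1 skips the gate/up activation fusion,
+                                     // so both the gate and up halves are written out and
+                                     // the N / NPadded extents of the output descriptors
+                                     // are doubled.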
2 : 1), + problem.StrideDs); const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( @@ -1875,7 +1894,8 @@ struct GridwiseMoeGemmBlockScale { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - scatter_offsets(m0) = token_offset * problem.N; + scatter_offsets(m0) = + token_offset * problem.N * (IsInputGemm && IsSplitK ? 2 : 1); }); block_sync_lds(); @@ -1953,8 +1973,8 @@ struct GridwiseMoeGemmBlockScale const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens, problem.MPadded, - problem.N, - problem.NPadded, + problem.N * (IsInputGemm && IsSplitK ? 2 : 1), + problem.NPadded * (IsInputGemm && IsSplitK ? 2 : 1), problem.StrideC); const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor( @@ -2125,8 +2145,10 @@ struct GridwiseMoeGemmBlockScale decltype(c_thread_buf) c_thread_buf_up; const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); + problem.KBatch == 1 + ? (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock + : problem.KBatch); // scale constexpr index_t ScaleSliceSizeM = MXdlPerWave; @@ -2202,7 +2224,7 @@ struct GridwiseMoeGemmBlockScale constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK); constexpr auto NumKBlockPerScale = math::integer_divide_ceil(ScaleBlockK, KPerBlock); - if constexpr(IsInputGemm) + if constexpr(IsInputGemm && !IsSplitK) { const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize; const auto b_grid_buf_up = make_dynamic_buffer( @@ -2352,7 +2374,7 @@ struct GridwiseMoeGemmBlockScale blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset( make_tuple(m0, n0, n2 * N4 + n4)); constexpr auto cidx = Number{}; - if constexpr(IsInputGemm) // gu fusion, elementwise + if constexpr(IsInputGemm && !IsSplitK) // gu fusion, elementwise { if constexpr(ActivationOperation == Activation::silu_and_mul) { @@ -2619,7 +2641,8 @@ struct GridwiseMoeGemmBlockScale { token_offset = token_offset * problem.TopK + (fused_token >> 24); } - scatter_offsets(m0) = token_offset * problem.N; + scatter_offsets(m0) = + token_offset * problem.N * (IsInputGemm && IsSplitK ? 
2 : 1); }); block_sync_lds(); diff --git a/include/ck_tile/core/tensor/tile_scatter_gather.hpp b/include/ck_tile/core/tensor/tile_scatter_gather.hpp index 7a4da64c4a..2ffaff2973 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather.hpp @@ -218,6 +218,44 @@ struct tile_scatter_gather pre_computed_coords_(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); }); + if constexpr(BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + auto partition_index = get_partition_index(tile_distribution); + + auto use_lane_id_0 = partition_index; + use_lane_id_0[1] = 0; + const auto window_adaptor_thread_coord_tmp_warp = make_tensor_adaptor_coordinate( + tile_distribution.get_ps_ys_to_xs_adaptor(), + container_concat(use_lane_id_0, array{0})); + + BottomTensorIndex bottom_tensor_thread_origin_idx_tmp_warp = + window_origin + window_adaptor_thread_coord_tmp_warp.get_bottom_index(); + bottom_tensor_thread_origin_idx_tmp_warp(HsGatherDim) = 0; + const auto bottom_tensor_thread_coord_tmp_warp = + make_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + bottom_tensor_thread_origin_idx_tmp_warp); + + // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up + // future load/store() calls (might allocate more registers) + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp_warp; + auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp_warp; + + constexpr auto idx_diff_ys = + SFC_Ys::get_step_between(number<0>{}, number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); + + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + + pre_computed_warp_coords_(iCoord) = + make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord); + }); + } } CK_TILE_DEVICE static constexpr index_t get_num_of_dimension() { return NDimBottomTensor; } @@ -602,6 +640,135 @@ struct tile_scatter_gather }); } + // TODO: fix with swizzle + template >>> + CK_TILE_DEVICE void async_load_with_offset(index_t offset, + LdsTileWindow_&& lds_tile, + number = {}, + bool_constant = {}, + bool_constant = {}) const + { + using LdsTileWindow = remove_cvref_t; + using LdsDataType = typename LdsTileWindow::DataType; + + using Traits = load_store_traits; + + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + // Precompute invariant values outside loops + const auto window_origin = lds_tile.get_window_origin(); + const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view(); + const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor(); + auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_; + + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1]; + + auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0]; + auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + constexpr auto idx_ys_offset = [&]() { + constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess); + constexpr auto 
adapter_ys_offset = make_tensor_adaptor_coordinate( + StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(), + container_concat(array{0}, + to_array(idx_off_ys))); + return adapter_ys_offset.get_bottom_index(); + }(); + const auto lds_ys_offset = [&]() { + if constexpr(static_move_ys) + { + const auto coord_ys_offset = + make_tensor_coordinate(tensor_descriptor, idx_ys_offset); + return coord_ys_offset.get_offset(); + } + else + return 0; + }(); + + // Use precomputed window origin & tensor descriptor + auto lds_bottom_tensor_thread_idx = + window_origin + window_adaptor_warp_coord.get_bottom_index(); + const auto lds_coord = + make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx); + + // Calculate SMEM address using base pointer + CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr + + lds_coord.get_offset() / Traits::PackedSize + + lds_ys_offset / Traits::PackedSize; + + const auto dram_ys_offset = [&]() { + if constexpr(static_move_ys) + { + const auto coord_ys_offset = make_tensor_coordinate( + this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset); + return coord_ys_offset.get_offset(); + } + else + return 0; + }(); + + constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + constexpr auto idx_gather = idx_ys_start[number{}]; + const auto page_offset = page_idx_[idx_gather]; + + auto mixed_bottom_thread_coord = bottom_tensor_thread_coord; + mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset; + + if constexpr(std::is_same_v) + { + this->get_bottom_tensor_view().template async_get_vectorized_elements( + smem, + mixed_bottom_thread_coord, + offset + dram_ys_offset, + bool_constant{}); + } + else + { + this->get_bottom_tensor_view().template async_get_vectorized_elements( + smem, + mixed_bottom_thread_coord, + offset + dram_ys_offset, + valids_[idx_gather], + bool_constant{}); + } + + // Move thread coordinate if not last access + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto forward_step_scatter = generate_tuple( + [&](auto i) { return i == YsGatherDim ? 
0 : idx_diff_ys[i]; }, + number{}); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + forward_step_scatter); + + if constexpr(!static_move_ys) + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, + bottom_tensor_thread_coord, + idx_diff_ps_ys); + + if constexpr(!static_move_ys) + move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys); + } + }); + }); + } + template CK_TILE_DEVICE void update(const static_distributed_tensor& dstr_tensor, number = {}, @@ -788,6 +955,15 @@ struct tile_scatter_gather pre_computed_coords_(iCoord)(I1), step_new); }); + if constexpr(BottomTensorView::buffer_view::get_address_space() == + address_space_enum::global) + { + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(), + pre_computed_warp_coords_(iCoord)(I1), + step_new); + }); + } } CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; } @@ -892,6 +1068,11 @@ struct tile_scatter_gather // per-thread coordinate for window adaptor // per-thread coordinate for bottom tensor array, NumCoord> pre_computed_coords_; + std::conditional_t, NumCoord>, + std::byte> + pre_computed_warp_coords_; }; // TODO: use strategy @@ -906,7 +1087,7 @@ make_tile_scatter_gather(const TensorView_& tensor_view, const WindowLengths_& window_lengths, const multi_index& origin, const StaticTileDistribution_& tile_distribution, - const StaticPageIndexArray_& page_idx, + const StaticPageIndexArray_& page_idx, // perbytes number = {}, number = {}) { diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index 368ffa96e2..cc3306f0fc 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -217,6 +217,7 @@ struct MoeFlatmmKernel static constexpr auto I1 = number<1>(); static constexpr auto I2 = number<2>(); static constexpr auto I3 = number<3>(); + static constexpr auto I4 = number<4>(); static_assert(DsLayout::size() == DsDataType::size(), "The size of DsLayout and DsDataType should be the same"); @@ -241,12 +242,24 @@ struct MoeFlatmmKernel IsGateUp ? TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock; // MXF4_Pipeline only has the of scale B and granularityK is 32 - static constexpr bool MXFP4_Pipeline = std::is_same_v; - static constexpr int MXFP4N_Pack = 2; - static constexpr int MXFP4K_Pack = 2; + static constexpr bool AQUANT_Pipeline = std::is_same_v || + std::is_same_v || + std::is_same_v; + static constexpr bool BMXFP4_Pipeline = std::is_same_v; - static constexpr int N_Pack = MXFP4_Pipeline ? MXFP4N_Pack : 1; - static constexpr int K_Pack = MXFP4_Pipeline ? MXFP4K_Pack : 1; + static constexpr bool MXF8F6F4MFMA = +#ifdef __gfx950__ + AQUANT_Pipeline && BMXFP4_Pipeline; +#else + false; +#endif + static constexpr int MXFP4M_Pack = 2; + static constexpr int MXFP4N_Pack = 2; + static constexpr int MXFP4K_Pack = 2; + + static constexpr int M_Pack = AQUANT_Pipeline ? MXFP4M_Pack : 1; + static constexpr int N_Pack = BMXFP4_Pipeline ? MXFP4N_Pack : 1; + static constexpr int K_Pack = BMXFP4_Pipeline ? 
MXFP4K_Pack : 1; static constexpr int WeightPackedSize = numeric_traits::PackedSize; @@ -659,23 +672,95 @@ struct MoeFlatmmKernel } }(); - auto scale_n = kargs.scale_n; - constexpr int GranularityK = decltype(scale_n)::GranularityK; + const auto& scale_a_tensor_view = [&]() { + auto scale_m_desc = kargs.scale_m; + if constexpr(AQUANT_Pipeline) + { + constexpr int AGranularityK = decltype(scale_m_desc)::GranularityK == 0 + ? 1 + : decltype(scale_m_desc)::GranularityK; - index_t scale_k = GranularityK == 0 ? 1 : (kargs.K + GranularityK - 1) / GranularityK; - index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1); - index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); + constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0); + constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0); + index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl); + index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl); + // Pack 2x2 e8m0 over M/K dimension into 1 int32_t to trigger dword width load + const auto scale_a_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_m_packs, scale_k_packs, KThreadPerXdl, MThreadPerXdl)); + const auto scale_a_desc = transform_tensor_descriptor( + scale_a_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_m_packs, MThreadPerXdl)), + make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view( + reinterpret_cast(scale_m_desc.ptr), scale_a_desc); + } + else + { + constexpr int AGranularityK = 32; + constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(I0); + constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I0); + index_t scale_m_packs = kargs.M / (MXFP4M_Pack * MThreadPerXdl); + index_t scale_k_packs = kargs.K / (MXFP4K_Pack * AGranularityK * KThreadPerXdl); + return make_naive_tensor_view( + reinterpret_cast(scale_m_desc.ptr), + make_tuple(scale_m_packs * MThreadPerXdl, scale_k_packs * KThreadPerXdl), + make_tuple(scale_k_packs * KThreadPerXdl, 1), + number<8>{}, + number<1>{}); + } + }(); - using ScaleType = std::conditional_t; + const auto scale_b_flat_view = [&]() { + auto scale_n = kargs.scale_n; + constexpr int BGranularityK = + decltype(scale_n)::GranularityK == 0 ? 1 : decltype(scale_n)::GranularityK; + if constexpr(AQUANT_Pipeline) + { + index_t scale_k = + BGranularityK == 0 ? 
1 : (kargs.K + BGranularityK - 1) / BGranularityK; + constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(I1); + constexpr int KThreadPerXdl = 64 / BlockGemmShape::WarpTile::at(I1); + index_t scale_n_packs = kargs.N / (MXFP4N_Pack * NThreadPerXdl); + index_t scale_k_packs = kargs.K / (MXFP4K_Pack * BGranularityK * KThreadPerXdl); + const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_n_packs, scale_k_packs, KThreadPerXdl, NThreadPerXdl)); + const auto scale_b_desc = transform_tensor_descriptor( + scale_b_navie_desc, + make_tuple(make_merge_transform(make_tuple(scale_n_packs, NThreadPerXdl)), + make_merge_transform(make_tuple(scale_k_packs, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - const auto scale_b_flat_view = make_naive_tensor_view( - reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, - make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK), - make_tuple(FlatScaleK, 1), - number<8>{}, - number<1>{}); + return make_tensor_view( + reinterpret_cast(scale_n.ptr) + + expert_id * kargs.N * scale_k / 4, + scale_b_desc); + } + else + { + index_t scale_k = + BGranularityK == 0 ? 1 : (kargs.K + BGranularityK - 1) / BGranularityK; + index_t FlatScaleK = scale_k * N_Pack * BlockGemmShape::WarpTile::at(I1); + index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); - return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view, scale_b_flat_view); + using ScaleType = e8m0_t; + + return make_naive_tensor_view( + reinterpret_cast(scale_n.ptr) + expert_id * kargs.N * scale_k, + make_tuple(FlatScaleN - kargs.n_padded_zeros / NPerXdl / N_Pack, FlatScaleK), + make_tuple(FlatScaleK, 1), + number<8>{}, + number<1>{}); + } + }(); + + return make_tuple(a_tensor_view, + b_flat_tensor_view, + c_tensor_view, + scale_a_tensor_view, + scale_b_flat_view); } template @@ -718,7 +803,7 @@ struct MoeFlatmmKernel } }(); - return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3)); + return make_tuple(a_pad_view, views.at(I1), c_pad_view, views.at(I3), views.at(I4)); } template @@ -747,7 +832,7 @@ struct MoeFlatmmKernel } }(); - constexpr bool isNonInterleaveGateUp = !IsGateUp || MXFP4_Pipeline; + constexpr bool isNonInterleaveGateUp = !IsGateUp || BMXFP4_Pipeline; const auto& b_flat_block_window = make_tile_window(b_flat_pad_view, @@ -766,17 +851,40 @@ struct MoeFlatmmKernel output_N_offset}); constexpr int GranularityK = 32; // fixed config for MXF4_Pipeline + auto a_scale_block_window = make_tile_window( + views.at(I3), + make_tuple(number{}, + number{}), + {coord_m / M_Pack, 0}); + constexpr int XDLPerLoadScaleB = - MXFP4_Pipeline ? 4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4 + BMXFP4_Pipeline ? 
4 : 1; // GranularityK32 / XDL16x16x32_K8 = 4 - auto scale_block_window = - make_tile_window(views.at(I3), - make_tuple(number{}, - number{}), - {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); + auto b_scale_block_window = [&]() { + if constexpr(MXF8F6F4MFMA) + { + return make_tile_window( + views.at(I4), + make_tuple(number{}, + number{}), + {coord_n / N_Pack, 0}); + } + else + { + return make_tile_window( + views.at(I4), + make_tuple(number{}, + number{}), + {coord_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); + } + }(); - return make_tuple(a_block_window, b_flat_block_window, c_block_window, scale_block_window); + return make_tuple(a_block_window, + b_flat_block_window, + c_block_window, + a_scale_block_window, + b_scale_block_window); } template @@ -831,7 +939,6 @@ struct MoeFlatmmKernel if(coord_m >= max_token_id) return; - static_for<0, DramMRepeat, 1>{}([&](auto m0) { const auto row_idx = coord_m + m0 * (TilePartitioner::MPerBlock / DramMRepeat) + a_coord[I0]; @@ -864,9 +971,10 @@ struct MoeFlatmmKernel const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& scale_block_window = gemm_tile_windows.at(I3); + const auto& a_block_window = gemm_tile_windows.at(I0); + const auto& b_block_window = gemm_tile_windows.at(I1); + const auto& a_scale_block_window = gemm_tile_windows.at(I3); + const auto& b_scale_block_window = gemm_tile_windows.at(I4); auto a_gather_block_tile = ck_tile::make_tile_scatter_gather(a_block_window.get_bottom_tensor_view(), @@ -876,17 +984,32 @@ struct MoeFlatmmKernel a_offsets); // K DRAM tile window for auto c_block_tile = [&] { - if constexpr(MXFP4_Pipeline) + if constexpr(BMXFP4_Pipeline) { - // MXFP4_Pipeline uses gate-up interleave 16 layout for weight + // BMXFP4_Pipeline uses gate-up interleave 16 layout for weight // so don't need extra processing - return FlatmmPipeline{}(a_gather_block_tile, - b_block_window, - scale_block_window, // weight scale with granularityK = 32 - num_loop, - kargs.k_padded_zeros, - smem_ptr_ping, - smem_ptr_pong); + if constexpr(AQUANT_Pipeline) + { + return FlatmmPipeline{}( + a_gather_block_tile, + b_block_window, + a_scale_block_window, // weight scale with granularityK = 32 + b_scale_block_window, // weight scale with granularityK = 32 + num_loop, + smem_ptr_ping, + smem_ptr_pong); + } + else + { + return FlatmmPipeline{}( + a_gather_block_tile, + b_block_window, + b_scale_block_window, // weight scale with granularityK = 32 + num_loop, + kargs.k_padded_zeros, + smem_ptr_ping, + smem_ptr_pong); + } } else { @@ -964,7 +1087,7 @@ struct MoeFlatmmKernel constexpr index_t ScaleMRepeat = MRepeat * kM0 * kM2; statically_indexed_array scale_m_offsets; - if constexpr(!MXFP4_Pipeline) + if constexpr(!BMXFP4_Pipeline) static_for<0, MRepeat, 1>{}([&](auto mIter) { static_for<0, kM0, 1>{}([&](auto m0) { static_for<0, kM2, 1>{}([&](auto m2) { @@ -1059,7 +1182,7 @@ struct MoeFlatmmKernel number<1>{}); auto exp_bias_window = make_tile_window( - permute_tensor_view(exp_bias_view, number<(MXFP4_Pipeline && !IsInputGemm)>{}), + permute_tensor_view(exp_bias_view, number<(BMXFP4_Pipeline && !IsInputGemm)>{}), make_tuple(number{}, number < IsGateUp ? 
TilePartitioner::NPerBlock / 2 : TilePartitioner::NPerBlock > {}), @@ -1101,7 +1224,7 @@ struct MoeFlatmmKernel ExpBiasBuffer exp_bias_buffer, exp_bias_up_buffer; ExpWeightBuffer exp_weight_buffer; - if constexpr(!MXFP4_Pipeline) + if constexpr(!BMXFP4_Pipeline) { scale_m_window.load(scale_m_buffer); scale_n_buffer = load_tile(scale_n_window); @@ -1233,7 +1356,7 @@ struct MoeFlatmmKernel auto epi_exp_bias_up = epi_tile_idx_slice(exp_bias_up_buffer, epi_m, epi_n); static_for<0, ActVectorSize, 1>{}([&](auto idx) { - if constexpr(!MXFP4_Pipeline) + if constexpr(!BMXFP4_Pipeline) { gate_tensor.get_thread_buffer()[idx] *= epi_scale_m[idx] * epi_scale_n[idx]; @@ -1260,7 +1383,7 @@ struct MoeFlatmmKernel auto epi_exp_bias = epi_tile_idx_slice(exp_bias_buffer, epi_m, epi_n); static_for<0, ActVectorSize, 1>{}([&](auto idx) { - if constexpr(!MXFP4_Pipeline) + if constexpr(!BMXFP4_Pipeline) lds_tile[lds_stage].get_thread_buffer()[idx] *= epi_scale_m[idx] * epi_scale_n[idx]; if constexpr(EnableBias) diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index 74d82b8949..11b978813a 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -1263,4 +1263,1174 @@ struct F16xMXF4FlatmmPipelineAGmemBGmemCRegV1 } }; +template +struct F8xMXF4FlatmmPipelineProblem : FlatmmPipelineProblem +{ + using BlockGemmShape = BlockGemmShape_; + + // using QuantType = BDataType_; + + static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp; + + static constexpr int ScaleGranularityK = 32; + + static constexpr int ContinuousKPerThread = 32; // it's fixed for fp4 + static constexpr int MXdlPack = 2; // it's fixed for fp4 + static constexpr int NXdlPack = 2; // it's fixed for fp4 + static constexpr int KXdlPack = 2; + // static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp * KXdlPack; + static constexpr index_t flatKPerWarp = get_warp_size() * ContinuousKPerThread; +}; + +template +struct F8xMXF4FlatmmPipelineAGmemBGmemCRegV1 + : FlatmmPipelineAGmemBGmemCRegV1 +{ + using Underlying = FlatmmPipelineAGmemBGmemCRegV1; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; // TileFlatmmShape + + using ComputeType = ADataType; + static_assert(sizeof(ADataType) >= sizeof(BDataType)); + + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using BlockFlatmm = + remove_cvref_t())>; + + static constexpr auto config = + BlockFlatmm::BlockPolicy::template GetWarpGemmMWarpNWarp(); + + using WG = remove_cvref_t())>; + + static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2 + static constexpr index_t DsReadPreload = 4; // default 4 for MXFP4 (MXdlPack * KXdlPack) + + static constexpr index_t BlockSize = Problem::kBlockSize; + static constexpr index_t WaveSize = get_warp_size(); + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + static constexpr index_t flatKPerWarp = Problem::flatKPerWarp; + static constexpr index_t flatNPerWarp = Problem::flatNPerWarp; + + static constexpr index_t GetVectorSizeA() { return 32; } /* fixed for fp4 shuffle 
layout*/ + static constexpr index_t GetVectorSizeB() { return 32; } /* fixed for fp4 shuffle layout*/ + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + // static constexpr index_t kLdsAlignmentInBytes = 16; + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel; + + static constexpr auto I0 = number<0>(); + static constexpr auto I1 = number<1>(); + static constexpr auto I2 = number<2>(); + static constexpr auto idxM = I0; + static constexpr auto idxN = I1; + static constexpr auto idxK = I2; + using BlockTile = remove_cvref_t; + using BlockWarps = remove_cvref_t; + using WarpTile = remove_cvref_t; + + static constexpr index_t MWarp = config.template at<1>(); + static constexpr index_t NWarp = config.template at<2>(); + + static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM); + static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); + static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK; + + static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp; + static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp; + + static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp; + static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp; + + static constexpr index_t APackedSize = numeric_traits::PackedSize; + static constexpr index_t BPackedSize = numeric_traits::PackedSize; + + static constexpr index_t MXdlPack = Problem::MXdlPack; + static constexpr index_t NXdlPack = Problem::NXdlPack; + static constexpr index_t KXdlPack = Problem::KXdlPack; + static constexpr index_t ScaleGranularityK = Problem::ScaleGranularityK; + + static constexpr index_t AK1 = + Problem::VectorLoadSize / sizeof(ADataType) * APackedSize; // 16 / 1 = 16 + static constexpr index_t BK1 = + Problem::VectorLoadSize / sizeof(BDataType) * BPackedSize; // 16 / 1 * 2 = 32 + + static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload) + ? 
DsReadPreload + : MIterPerWarp * KIterPerWarp; + + static constexpr bool HasHotLoop = Problem::HasHotLoop; + static constexpr auto TailNum = Problem::TailNum; + + static constexpr index_t mfma_per_wg = 1; // 950 only + + static constexpr index_t dsread_per_wg = + WG::kM * WG::kK / AK1 / WaveSize; // 16 * 128 / 16 / 64 = 2 + static_assert((WG::kM * WG::kK) % (AK1 * WaveSize) == 0); // 16 * 128 % 16 * 64 + + static constexpr index_t dsread_num_perK = dsread_per_wg * MIterPerWarp; + static constexpr index_t dswrite_num_perK = dsread_num_perK / NWarp; + static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp; + static constexpr index_t Aload_num_perK = dswrite_num_perK; + static constexpr index_t Aload_rep = dswrite_rep; + + static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / BK1 / WaveSize; + static constexpr index_t Bload_num = Bload_num_perK * KIterPerWarp; + static constexpr index_t ScaleBload_num = + kNPerBlock * kKPerBlock / NWarp / ScaleGranularityK / NXdlPack / KXdlPack / WaveSize; + static constexpr index_t ScaleAload_num = + kMPerBlock * kKPerBlock / MWarp / ScaleGranularityK / MXdlPack / KXdlPack / WaveSize; + + // static constexpr index_t KPerScaleLoad = KIterPerWarp / ScaleBload_num; + static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2; + static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter; + + static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg; + static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp; + static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp; + + // For the basic gemm pipelien DoubleSmemBuffer set to be false naturally. + static constexpr bool DoubleSmemBuffer = false; + + static constexpr auto BMemNTType = Problem::BMemNTType; + static constexpr bool BPreShufflePermute = Problem::BPreShufflePermute; + + CK_TILE_HOST_DEVICE static constexpr auto + SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM) + { + // Init inst order + index_t max_data_inst = dsread_perM > load_perM + ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM) + : (load_perM > dswrite_perM ? load_perM : dswrite_perM); + index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM; + index_t round_data_inst = (sum_data_inst + mfma_perM_perK - 1) / mfma_perM_perK; + + index_t inst_order[NIterPerWarp * 10]; + _Pragma("unroll") for(int idx = 0; idx < NIterPerWarp * 10; idx++) { inst_order[idx] = 0; } + + index_t index = 0; + _Pragma("unroll") for(int j = 0; j < max_data_inst; j++) + { + if(dswrite_perM > j) + { + inst_order[index] = 1; + index++; + } + if(load_perM > j) + { + inst_order[index] = 2; + index++; + } + if(dsread_perM > j) + { + inst_order[index] = 3; + index++; + } + } + + // Schedule IGLP + _Pragma("unroll") for(int j = 0; j < mfma_perM_perK; j++) + { + index_t inst_idx = 0; + if(j == 0) + ; + else if(j == 1) + inst_idx = mfma_perM_perK == 2 ? 
1 : mfma_perM_perK - 2; + else if(j == 2) + inst_idx = mfma_perM_perK - 1; + else + inst_idx = mfma_perM_perK - j; + + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + + _Pragma("unroll") for(int r = 0; r < round_data_inst; r++) + { + if(r % 2 == 0) + { + if(inst_order[inst_idx + r * mfma_perM_perK] == 1) + { + // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if(inst_order[inst_idx + r * mfma_perM_perK] == 2) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if(inst_order[inst_idx + r * mfma_perM_perK] == 3) + { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } + else + { + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1) + { + // __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write + } + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2) + { + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + } + if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3) + { + __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read + } + } + } + } + } + + CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler() + { + // Keypoint of pipeline optimize is workload balance in time + // instruction schedule example(128X256X256, 1X4, 16X16X128): + // Iter MNK MFMA ds_read ds_write A_load b_load + // -1 M6N0: 57 - 8 - - + // -1 M6N1: 58 1 - - - + // -1 M6N2: 59 - - 7 - + // -1 M6N3: 60 2 - - - + // -1 M7N0: 61 - - - - + // -1 M7N1: 62 3 - - - + // -1 M7N2: 63 - - 8 - + // -1 M7N3: 64 4 - - - + // 0 M0N0K0: 1 - - - 1 + // 0 M0N1: 2 5 - - - + // 0 M0N2: 3 - - - 2 + // 0 M0N3: 4 6 - - - + // 0 M1N0: 5 - - - 3 + // 0 M1N1: 6 7 - - - + // 0 M1N2: 7 - - - 4 + // 0 M1N3: 8 8 - - - + // 0 M2N0: 9 - - - 5 + // 0 M2N1: 10 9 - - - + // 0 M2N2: 11 - - - 6 + // 0 M2N3: 12 10 - - - + // 0 M3N0: 13 - 1 - 7 + // 0 M3N1: 14 11 - - - + // 0 M3N2: 15 - - - 8 + // 0 M3N3: 16 12 - - - + // 0 M4N0: 17 - 2 - - + // 0 M4N1: 18 13 - - - + // 0 M4N2: 19 - - 1 - + // 0 M4N3: 20 14 - - - + // 0 M5N0: 21 - 3 - - + // 0 M5N1: 22 15 - - - + // 0 M5N2: 23 - - 2 - + // 0 M5N3: 24 16 - - - + // 0 M6N0: 25 - 4 - - + // 0 M6N1: 26 17 - - - + // 0 M6N2: 27 - - 3 - + // 0 M6N3: 28 18 - - - + // 0 M7N0: 29 - - - - + // 0 M7N1: 30 19 - - - + // 0 M7N2: 31 - - 4 - + // 0 M7N3: 32 20 - - - + // 0 M0N0K1: 33 - - - 9 + // 0 M0N1: 34 21 - - - + // 0 M0N2: 35 - - - 10 + // 0 M0N3: 36 22 - - - + // 0 M1N0: 37 - - - 11 + // 0 M1N1: 38 23 - - - + // 0 M1N2: 39 - - - 12 + // 0 M1N3: 40 24 - - - + // 0 M2N0: 41 - - - 13 + // 0 M2N1: 42 25 - - - + // 0 M2N2: 43 - - - 14 + // 0 M2N3: 44 26 - - - + // 0 M3N0: 45 - 5 - 15 + // 0 M3N1: 46 27 - - - + // 0 M3N2: 47 - - - 16 + // 0 M3N3: 48 28 - - - + // 0 M4N0: 49 - 6 - - + // 0 M4N1: 50 29 - - - + // 0 M4N2: 51 - - 5 - + // 0 M4N3: 52 30 - - - + // 0 M5N0: 53 - 7 - - + // 0 M5N1: 54 31 - - - + // 0 M5N2: 55 - - 6 - + // 0 M5N3: 56 32 - - - + // 0 M6N0: 57 - 8 - - + // 0 M6N1: 58 1 - - - + // 0 M6N2: 59 - - 7 - + // 0 M6N3: 60 2 - - - + // 0 M7N0: 61 - - - - + // 0 M7N1: 62 3 - - - + // 0 M7N2: 63 - - 8 - + // 0 M7N3: 64 4 - - - + + _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { + _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + dsread_perM = dsread_per_wg; + + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? 
dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } + + // Calculate buffer_load number per M + if(mIter < HalfMIter) + { + load_perM = + ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep + : 0) + + ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep + : 0); + } + else + { + load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 + ? Aload_rep + : 0; + } + // if((kIter % KPerScaleLoad == 0) && (mIter == 0)) + // { + // load_perM = load_perM + 1; + // } + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + // Add Aload when Aload data > needed + if(Aload_num_perK == 0) + __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read + __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler() + { + _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { + _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + dsread_perM = dsread_per_wg; + + // Calculate ds_write number per M + if(mIter == 0) + { + dswrite_perM = + (dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0 + ? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep + : 0; + } + else if(mIter >= MIterPerWarp - DsWritePreIssue + 1) + { + dswrite_perM = 0; + } + else + { + dswrite_perM = (dswrite_num_perK - + (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0 + ? dswrite_rep + : 0; + } + // Add ds write when ds write data > needed + if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter)) + { + if(mIter == MIterPerWarp - 1 - dswrite_mIter) + dswrite_perM = 1; + } + + // Calculate buffer_load number per M + if(mIter < HalfMIter) + { + load_perM = + ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? 
Bload_rep + : 0); + } + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler() + { + _Pragma("unroll") for(int kIter = 0; kIter < KIterPerWarp; kIter++) + { + _Pragma("unroll") for(int mIter = 0; mIter < MIterPerWarp; mIter++) + { + index_t dsread_perM = 0; + index_t dswrite_perM = 0; + index_t load_perM = 0; + + // Calculate ds_read number per M + if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload)) + dsread_perM = dsread_per_wg; + + SchedulerPerM(dsread_perM, dswrite_perM, load_perM); + } + } + // __builtin_amdgcn_sched_barrier(0); + } + + CK_TILE_HOST_DEVICE static constexpr auto GetADramTileDistribution() + { + return PipelinePolicy::template MakeADramTileDistribution(); + } + + template + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* __restrict__ p_smem_ping, + void* __restrict__ p_smem_pong) const + { +#ifndef __gfx950__ + static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now."); +#endif + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // constexpr auto MIter_2nd_last = max(0, MIterPerWarp - 2); + static_assert(NWarp == 4); + + using CWarpDstr = typename WG::CWarpDstr; + using CWarpTensor = typename WG::CWarpTensor; + + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + auto a_dram_window = replace_bottom_tensor_view( + PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor( + // PipelinePolicy::template Make_F8AAsyncLoadDramDescriptor( + a_copy_dram_window_tmp.get_bottom_tensor_view()), + a_copy_dram_window_tmp); + + __builtin_amdgcn_sched_barrier(0); + + // A tile in LDS + ADataType* p_a_lds_ping = static_cast(p_smem_ping); + ADataType* p_a_lds_pong = static_cast(p_smem_pong); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor(); + // PipelinePolicy::template MakeF8_ReadALdsBlockDescriptor(); + + constexpr auto a_load_lds_block_desc = + PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor(); + // PipelinePolicy::template MakeF8_WriteALdsBlockDescriptor(); + + auto a_lds_block_ping_load = + make_tensor_view(p_a_lds_ping, a_load_lds_block_desc); + auto a_lds_block_pong_load = + make_tensor_view(p_a_lds_pong, a_load_lds_block_desc); + + auto a_lds_block_ping = + make_tensor_view(p_a_lds_ping, a_lds_block_desc); + auto a_lds_block_pong = + make_tensor_view(p_a_lds_pong, a_lds_block_desc); + + auto a_store_lds_window_ping = make_tile_window( + a_lds_block_ping_load, make_tuple(number{}, number{}), {0, 0}); + auto a_store_lds_window_pong = make_tile_window( + a_lds_block_pong_load, make_tuple(number{}, number{}), {0, 0}); + + // ping-pong window for A LDS + auto a_warp_window_ping = + make_tile_window(a_lds_block_ping, + make_tuple(number{}, number{}), + {0, 0}, + PipelinePolicy::template MakeMXF4_ALDS_TileDistribution()); + auto a_warp_window_pong = + make_tile_window(a_lds_block_pong, + make_tuple(number{}, number{}), + {0, 0}, + 
PipelinePolicy::template MakeMXF4_ALDS_TileDistribution()); + + // Block GEMM + auto block_flatmm = BlockFlatmm(); + // Acc register tile + auto c_block_tile = block_flatmm.MakeCBlockTile(); + + // B flat DRAM window for load + + // pingpong buffer for B + auto b_flat_dram_windows = generate_tuple( + [&](auto nIter) { + constexpr auto packed_n_idx = nIter / number{}; + constexpr auto packed_n_rank = nIter % number{}; + auto window_i = make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution()); + move_tile_window( + window_i, + {number{}, + number<0>{}}); + return window_i; + }, + number{}); + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_tensor_ping, b_warp_tensor_pong; + + // pingpong buffer for Scale A and Scale B + auto scale_a_dram_window = make_tile_window( + scale_a_window.get_bottom_tensor_view(), + make_tuple(number{}, number<64 / WG::kM>{}), + scale_a_window.get_window_origin(), + PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution()); + + auto scale_b_dram_window = make_tile_window( + scale_b_window.get_bottom_tensor_view(), + make_tuple(number{}, number<64 / WG::kN>{}), + scale_b_window.get_window_origin(), + PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution()); + + // ping pong buffer for scale A + statically_indexed_array< + statically_indexed_array, + MIterPerWarp / MXdlPack> + scale_a_dram_windows; + statically_indexed_array, + MIterPerWarp / MXdlPack> + scale_a_tile_tensor_ping; + statically_indexed_array, + MIterPerWarp / MXdlPack> + scale_a_tile_tensor_pong; + + // ping pong buffer for scale B + statically_indexed_array< + statically_indexed_array, + NIterPerWarp / NXdlPack> + scale_b_dram_windows; + statically_indexed_array, + NIterPerWarp / NXdlPack> + scale_b_tile_tensor_ping; + statically_indexed_array, + NIterPerWarp / NXdlPack> + scale_b_tile_tensor_pong; + + auto async_load_tile_ = [](auto lds, auto dram) { + async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{}); + }; + + // HEAD + // Prefetch A0 + async_load_tile_(a_store_lds_window_ping, a_dram_window); + move_tile_window(a_dram_window, {0, kKPerBlock}); + + // prefetch B + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), number{}); + }); + // move B window to next flat K + move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter}); + }); + + // prefetch Scale A + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + }); + }); + // move Scale A window to next K + move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + + // prefetch Scale B + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + 
move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + }); + }); + // move Scale B window to next K + move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + __builtin_amdgcn_sched_barrier(0); + + // Prefetch A1 + if constexpr(HasHotLoop || TailNum == TailNumber::Even) + { + async_load_tile_(a_store_lds_window_pong, a_dram_window); + move_tile_window(a_dram_window, {0, kKPerBlock}); + } + // initialize C + clear_tile(c_block_tile); + + statically_indexed_array a_warp_tensor; + + // preload A00,A10... from lds + s_waitcnt_barrier(); + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MXdlPack; + constexpr auto kIter = loadIter / MXdlPack; + + a_warp_tensor(loadIter) = load_tile_with_offset( + a_warp_window_ping, tuple, number>{}); + }); + __builtin_amdgcn_sched_barrier(0); + + // MAIN LOOP + auto main_body_implx2 = [&]() mutable { + // prefetch B(2i+1) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), number{}); + if constexpr(kIter == KIterPerWarp - 1) + move_tile_window(b_flat_dram_windows(nIter), + {0, BlockGemmShape::flatKPerBlock}); + }); + }); + + // prefetch Scale A and Scale B (2i+1) + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + }); + }); + + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + }); + }); + + // GEMM 2i + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], 
+ scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } + }); + }); + }); + }); + }); + // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished + s_waitcnt< // vmcnt + Bload_num + ScaleAload_num + ScaleBload_num>(); + block_sync_lds(); + + // Prefetch A(2i+2) + async_load_tile_(a_store_lds_window_ping, a_dram_window); + move_tile_window(a_dram_window, {0, kKPerBlock}); + + // move B window to next flat K + move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + + // preload A(2i+1) + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MXdlPack; + constexpr auto kIter = loadIter / MXdlPack; + a_warp_tensor(loadIter) = load_tile_with_offset( + a_warp_window_pong, tuple, number>{}); + }); + HotLoopScheduler(); + + ////////////////////////////// Next K ////////////////////////////// + + // prefetch B(2i+2) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), number{}); + if constexpr(kIter == KIterPerWarp - 1) + move_tile_window(b_flat_dram_windows(nIter), + {0, BlockGemmShape::flatKPerBlock}); + }); + }); + + // prefetch Scale A and Scale B (2i+2) + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + }); + }); + + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + }); + }); + + // GEMM 2i+1 + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + 
c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, number>{}); + } + }); + }); + }); + }); + }); + // barrier as ds_load A(2i + 1) and buffer_load_lds A(2i + 2) finished + s_waitcnt< // vmcnt + Bload_num + ScaleAload_num + ScaleBload_num>(); + block_sync_lds(); + + // Prefetch A(2i+3) + async_load_tile_(a_store_lds_window_pong, a_dram_window); + move_tile_window(a_dram_window, {0, kKPerBlock}); + // move B window to next flat K + move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + move_tile_window(scale_b_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); + + // preload A(2i+2) + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MXdlPack; + constexpr auto kIter = loadIter / MXdlPack; + a_warp_tensor(loadIter) = load_tile_with_offset( + a_warp_window_ping, tuple, number>{}); + }); + HotLoopScheduler(); + }; + + if constexpr(HasHotLoop) + { + index_t iCounter = (num_loop - 1) / 2; + while(iCounter > 0) + { + main_body_implx2(); + iCounter--; + } + } + + // TAIL + if constexpr(TailNum == TailNumber::Even) + { + // prefetch B(loopK) + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( + b_flat_dram_windows(nIter), + make_tuple(number<0>{}, number{})); + }); + }); + + // prefetch Scale A and Scale B (2i+1) + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; + move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), + {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = + load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + }); + }); + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; + move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), + {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); + + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = + load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + }); + }); + + // GEMM loopK-1 + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / 
MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}.template + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + (kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } + }); + }); + }); + }); + }); + // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished + s_waitcnt< // vmcnt + Bload_num + ScaleAload_num + ScaleBload_num>(); + block_sync_lds(); + + // preload A(2i+1) + static_for<0, m_preload, 1>{}([&](auto loadIter) { + constexpr auto mIter = loadIter % MXdlPack; + constexpr auto kIter = loadIter / MXdlPack; + a_warp_tensor(loadIter) = load_tile_with_offset( + a_warp_window_pong, tuple, number>{}); + }); + + // Last2ndHotLoopScheduler(); + + // GEMM loopK + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences( + sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}.template + operator()( + // operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_pong(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], // scale A + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); // scale B + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + + 
(kIter_pack * KXdlPack + ikxdl) * 2 + + (mIter_pack * MXdlPack + imxdl) / 2 * 4 + + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_pong, + tuple, number>{}); + } + }); + }); + }); + }); + }); + // LastHotLoopScheduler(); + } + else if constexpr(TailNum == TailNumber::Odd) + { + // GEMM loopK + static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { + static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { + static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { + static_for<0, MXdlPack, 1>{}([&](auto imxdl) { + constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; + static_for<0, NXdlPack, 1>{}([&](auto inxdl) { + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; + + // read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + c_warp_tensor.get_thread_buffer() = + c_block_tile.get_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // warp GEMM + WG{}.template + // operator()( + operator()( + c_warp_tensor, + a_warp_tensor(number{}), + b_warp_tensor_ping(nIter_pack * number{} + inxdl)( + kIter_pack * number{} + ikxdl), + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) + .get_thread_buffer()[0], + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) + .get_thread_buffer()[0]); + + // write C warp tensor into C block tensor + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, + c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + // preload next A from lds + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; + if constexpr(addr < (KIterPerWarp * MIterPerWarp) && + (nIter_pack == NIterPerWarp / NXdlPack - 1)) + { + constexpr auto AmIter = addr % 2 + addr / 4 * 2; + constexpr auto AkIter = addr / 2 % 2; + a_warp_tensor(number{}) = load_tile_with_offset( + a_warp_window_ping, + tuple, number>{}); + } + }); + }); + }); + }); + }); + // barrier as ds_load A(2i) and buffer_load_lds A(2i + 1) finished + s_waitcnt< // vmcnt + Bload_num + ScaleAload_num + ScaleBload_num>(); + block_sync_lds(); + } + else + { + static_assert(false, "Wrong TailNum"); + } + return c_block_tile; + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp index ea67d80e37..c773cbf736 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mixed_prec_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp @@ -156,7 +156,7 @@ struct F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy constexpr int K_Lane = 64 / TileShape::WarpTile::at(I1); // 4 - constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 8 + constexpr int K2 = TileShape::WarpTile::at(I2) / K_Lane; // 128 / 4 = 32 constexpr int XDL_PerThreadK = KBPerLoad / K2; // 4 constexpr int K0 = K_Lane; // 4 @@ -236,4 +236,513 @@ struct 
F16xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } }; +struct F8xMXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy +{ + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + static constexpr auto I2 = number<2>{}; + + static constexpr index_t kDramLoadPackBytes = 128; + + static constexpr int MXdlPack = 2; + static constexpr int NXdlPack = 2; + static constexpr int KXdlPack = 2; + + template + static inline constexpr auto wg_attr_num_access = WGAttrNumAccessEnum::Single; + // std::is_same_v, pk_fp4_t> + // ? WGAttrNumAccessEnum::Single + // : WGAttrNumAccessEnum::Double; + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockFlatmm() + { + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + + using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher< // + ADataType, + BDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC, + false, + false, + wg_attr_num_access>; + using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy< // + ADataType, + BDataType, + typename Problem::CDataType, + BlockWarps, + WarpGemm>; + return BlockFlatmmASmemBSmemCRegV1{}; + } + + template + CK_TILE_DEVICE static constexpr auto + MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view) + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); + constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); + static_assert(MPerXdl == 16 && NPerXdl == 16); + static_assert(std::is_same_v); + + const auto& naive_desc = naive_view.get_tensor_descriptor(); + constexpr auto ndims = remove_cvref_t::get_num_of_dimension(); + static_assert(ndims == 2, "only support 2D tensor"); + const auto rows = naive_desc.get_length(number<0>{}); + const auto cols = naive_desc.get_length(number<1>{}); + + constexpr index_t APackedSize = numeric_traits::PackedSize; + constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 + const index_t K0 = cols / (K1 * K2); + const auto col_lens = make_tuple(K0, number{}, number{}); + + constexpr index_t M1 = 4; // so that we can use imm offset to load lds + const index_t M0 = rows / M1; + const auto row_lens = make_tuple(M0, number{}); + + const auto desc_0 = + make_naive_tensor_descriptor_packed(container_concat(row_lens, col_lens)); + const auto desc_1 = transform_tensor_descriptor( + desc_0, + make_tuple(make_pass_through_transform(M0), + make_xor_transform(make_tuple(number{}, number{})), + make_pass_through_transform(K0), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{})); + const auto desc = transform_tensor_descriptor( // + desc_1, + make_tuple(make_merge_transform_v3_division_mod(row_lens), + make_merge_transform_v3_division_mod(col_lens)), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + // printf("A async load dram desc %d x %d: \n", desc.get_length(I0), desc.get_length(I1)); + + return tensor_view, + TensorView::DstInMemOp>{naive_view.buf_, desc}; + } + + template + CK_TILE_DEVICE static constexpr auto + 
Make_F8AAsyncLoadDramDescriptor(const TensorView& naive_view) + { + constexpr int DynamicTileOffsetFlag = 0; + + constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); + constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); + + static_assert(MPerXdl == 16 && NPerXdl == 16); + + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr int ContiguousThreadsCntInDS_READ_16B = 4; + + // implement swizzle pattern on global side + // because we can't adjust the ds_write pattern of BUFFER_LOAD_LDS. + auto swizzle_a_dram_view_1 = transform_tensor_view( + naive_view, + make_tuple( + // M-dim is not affected by swizzle pattern + make_unmerge_transform( + make_tuple(number{}, number{})), + // K-dim is the swizzle dimension + make_unmerge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{})); + + auto swizzle_a_dram_view_2 = transform_tensor_view( + swizzle_a_dram_view_1, + make_tuple(make_pass_through_transform(number{}), + make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1, 3>{}, sequence<2>{}, sequence<4>{})); + + return transform_tensor_view( + swizzle_a_dram_view_2, + make_tuple( + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } + + template + CK_TILE_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + static_assert(std::is_same_v); + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t APackedSize = numeric_traits::PackedSize; + + constexpr index_t K2 = MPerBlock == 16 + ? 
GetSmemPackA() * APackedSize / 4 + : GetSmemPackA() * APackedSize; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 + constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 + + constexpr index_t M2 = get_warp_size() / K1; // 8 + constexpr index_t M1 = BlockSize / get_warp_size(); // 4 + constexpr index_t M0 = MPerBlock / (M2 * M1); + static_assert(M0 * M1 * M2 == MPerBlock, "M0, M1, M2 must cover whole MPerBlock!"); + static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); + + return make_static_tile_distribution( + tile_distribution_encoding< // + sequence<1>, + tuple, sequence>, // ?,4,8 1,8,32 or 2,8,16 + tuple, sequence<1, 2>>, // M1 M2,K1 + tuple, sequence<2, 1>>, + sequence<1, 2, 2>, // M0,K0,K2 + sequence<0, 0, 2>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor() + { + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; + constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); + constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); + static_assert(MPerXdl == 16 && NPerXdl == 16); + static_assert(std::is_same_v); + + /*reduce transform layers,compare with old ck*/ + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t APackedSize = numeric_traits::PackedSize; + constexpr index_t K2 = GetSmemPackA() * APackedSize; // f4=32; f8=16 + constexpr index_t K1 = kDramLoadPackBytes * APackedSize / K2; // 8 + constexpr index_t K0 = KPerBlock / (K1 * K2); // KPerBlock/256 + static_assert(K0 * K1 * K2 == KPerBlock, "K0, K1, K2 must cover whole KPerBlock!"); + + constexpr index_t M3 = 4; // so that we can use imm offset to load lds + constexpr index_t M2 = get_warp_size() / K1 / M3; // 2 + constexpr index_t M1 = MPerXdl / (M2 * M3); // 2 + constexpr index_t M0 = MPerBlock / (M1 * M2 * M3); // MPerBlock/16 + static_assert(M0 * M1 * M2 * M3 == MPerBlock, "M0, M1, M2, M3 must cover whole MPerBlock!"); + + constexpr index_t Pad = 4 * K2; // 4 * 16 + // constexpr index_t Pad = 0; // 4 * 16 + + // TODO: fix lds_a swizzle + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{}))), + make_tuple(sequence<0, 1, 3, 4>{}, sequence<2, 5, 6>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // return a_lds_block_desc_permuted; + return a_lds_block_desc; + } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeF8_ReadALdsBlockDescriptor() + { + constexpr index_t MPerXdl = Problem::BlockGemmShape::WarpTile::at(I0); + constexpr index_t NPerXdl = Problem::BlockGemmShape::WarpTile::at(I1); + + static_assert(MPerXdl == 16 && NPerXdl == 16); + + /*reduce transform layers,compare with old ck*/ + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, 
number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr int ContiguousThreadsCntInDS_READ_16B = 4; + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1>{}, sequence<0, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeF8_WriteALdsBlockDescriptor() + { + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetSmemPackA(); + return make_naive_tensor_descriptor(make_tuple(number{}, number{}), + make_tuple(number{}, number<1>{}), + number{}, + number<1>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + static_assert(TileShape::WarpTile::at(I1) == 16, "requires XDL_N == 16"); + static_assert(TileShape::BlockWarps::at(I0) == 1, "requires Wave_M == 1"); + + constexpr int M_warps = TileShape::BlockWarps::at(number<0>{}); + constexpr int N_warps = TileShape::BlockWarps::at(number<1>{}); + constexpr int M_Lane = TileShape::WarpTile::at(I0); // 16 + + constexpr int K_Lane = 64 / M_Lane; // 4 + + constexpr int K_Thread = TileShape::WarpTile::at(I2) / K_Lane; // 32 + // constexpr index_t num_access_v = static_cast(wg_attr_num_access); + constexpr index_t num_access_v = 2; + constexpr int K1 = K_Thread / num_access_v; // 16 + + return make_static_tile_distribution( + std::conditional_t< + num_access_v == 1, + tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<2>, + sequence<1>>, + tile_distribution_encoding< // + sequence, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 2>>, + sequence<2, 2>, + sequence<0, 2>>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + static_assert(TileShape::WarpTile::at(I1) == 16, "only for XDL_N == 16"); + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t K1 = WaveSize; // threads cnt in K dim + constexpr index_t KWavePerBlk = 1; + constexpr index_t K0 = KWavePerBlk; + + constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp + + constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp; + constexpr index_t kKPerThread = 32; + constexpr index_t num_access_v = static_cast(wg_attr_num_access); + constexpr index_t K2 = kKPerThread / num_access_v; + + return make_static_tile_distribution( + std::conditional_t< // + num_access_v == 1, + tile_distribution_encoding< // + sequence, + tuple, // 4 2 + sequence>, // 1 64 32 + tuple, sequence<2>>, + tuple, sequence<1>>, + sequence<2>, + sequence<2>>, + tile_distribution_encoding< // + sequence, + tuple, // 4 2 + sequence>, // 2 1 64 16 + tuple, 
sequence<2>>, + tuple, sequence<2>>, + sequence<2, 2>, + sequence<0, 3>>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t kMPerBlock = TileShape::BlockTile::at(I0); + + constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); + constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); + + static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); + + constexpr index_t M_Lanes = TileShape::WarpTile::at(I0); + constexpr index_t K_Lanes = 64 / M_Lanes; + + // Y dimension (M) decomposition + constexpr index_t Y2 = M_Lanes; + constexpr index_t Y1 = M_Warps; + constexpr index_t Y0 = kMPerBlock / (MXdlPack * Y1 * Y2); + + // X dimension (K) decomposition + constexpr index_t X0 = K_Lanes; + constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load + + return make_static_tile_distribution( + tile_distribution_encoding, // repeat N_warps + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t WaveSize = get_warp_size(); + constexpr index_t WaveNum = BlockSize / WaveSize; + + constexpr index_t kNPerBlock = TileShape::BlockTile::at(I1); + + constexpr index_t M_Warps = TileShape::BlockWarps::at(I0); + constexpr index_t N_Warps = TileShape::BlockWarps::at(I1); + + static_assert(WaveNum == M_Warps * N_Warps, "Block warps do not match block size"); + + constexpr index_t N_Lanes = TileShape::WarpTile::at(I1); + constexpr index_t K_Lanes = 64 / N_Lanes; + + // Y dimension (M) decomposition + constexpr index_t Y2 = N_Lanes; + constexpr index_t Y1 = N_Warps; + constexpr index_t Y0 = kNPerBlock / (NXdlPack * Y1 * Y2); + + // X dimension (K) decomposition + constexpr index_t X0 = K_Lanes; + constexpr index_t X1 = 1; // packed 2x2 E8M0 data into 1 int32_t for load + + return make_static_tile_distribution( + tile_distribution_encoding, // ? + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t M_Warp = TileShape::BlockWarps::at(number<0>{}); + constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I0); + constexpr index_t M_Lane = TileShape::WarpTile::at(I0); + constexpr index_t N_Wrap = TileShape::BlockWarps::at(number<1>{}); + constexpr index_t MWavePerBlk = M_Warp; + + return make_static_tile_distribution( + tile_distribution_encoding, // ? 
+ tuple, // second direction + sequence>, // first direction + tuple, sequence<2, 1>>, // which direction + tuple, sequence<0, 1>>, // which index + // + sequence<2>, + sequence<1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution() + { + using TileShape = typename Problem::BlockGemmShape; + + constexpr index_t N_Warp = TileShape::BlockWarps::at(number<1>{}); + constexpr index_t K_Lane = 64 / TileShape::WarpTile::at(I1); + constexpr index_t N_Lane = TileShape::WarpTile::at(I1); + constexpr index_t M_Wrap = TileShape::BlockWarps::at(number<0>{}); + constexpr index_t NWavePerBlk = N_Warp; + + return make_static_tile_distribution( + tile_distribution_encoding, // ? + tuple, // second direction + sequence>, // first direction + tuple, sequence<2, 1>>, // which direction + tuple, sequence<0, 1>>, // which index + // + sequence<2>, + sequence<1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() + { + using ADataType = remove_cvref_t; + constexpr index_t APackedSize = numeric_traits::PackedSize; + return sizeof(ADataType) * + MakeMXFP4_ALdsBlockDescriptor().get_element_space_size() / APackedSize; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return GetSmemSizeA(); + } +}; + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 44a09423ee..c0fbf8e5d3 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -309,25 +309,6 @@ using WarpGemmMfma_f32_16x16x64_bf8_bf8 = WarpGemmImpl using WarpGemmMfma_f32_16x16x128_f8f6f4 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; -template -using WarpGemmMfma_f32_16x16x128_fp8_fp8 = WarpGemmImpl< // - WarpGemmAttributeMfma, - AttrNumAccess>>; - -template -using WarpGemmMfma_f32_16x16x128_fp8_bf8 = WarpGemmImpl< // - WarpGemmAttributeMfma, - AttrNumAccess>>; - -template -using WarpGemmMfma_f32_16x16x128_bf8_fp8 = WarpGemmImpl< // - WarpGemmAttributeMfma, - AttrNumAccess>>; - -template -using WarpGemmMfma_f32_16x16x128_bf8_bf8 = WarpGemmImpl< // - WarpGemmAttributeMfma, - AttrNumAccess>>; template using WarpGemmMfma_f32_16x16x128_fp8_fp8_CTransposed = diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp new file mode 100644 index 0000000000..9d9b8a62f5 --- /dev/null +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm1_blockscale_splitk.hpp @@ -0,0 +1,232 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/device_base.hpp" +#include "ck/library/utility/host_tensor.hpp" + +namespace ck { +namespace tensor_operation { +namespace host { + +template +struct ReferenceMoeGemm1BlockScaleSplitK : public device::BaseOperator +{ + // Argument + struct Argument : public device::BaseArgument + { + Argument(const Tensor& sorted_token_ids, + const Tensor& expert_ids, + const Tensor& max_token_id, + const index_t sorted_tile_size, + const Tensor& a_t_k, + const Tensor& b_e_n_k, + Tensor& c_t_k_n, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + : sorted_token_ids_{sorted_token_ids}, + expert_ids_{expert_ids}, + max_token_id_{max_token_id}, + sorted_tile_size_{sorted_tile_size}, + a_t_k_{a_t_k}, + b_e_n_k_{b_e_n_k}, + c_t_k_n_{c_t_k_n}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + c_element_op_{c_element_op} + { + } + + const Tensor& sorted_token_ids_; + const Tensor& expert_ids_; + const Tensor& max_token_id_; + index_t sorted_tile_size_; + const Tensor& a_t_k_; + const Tensor& b_e_n_k_; + Tensor& c_t_k_n_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CElementwiseOperation c_element_op_; + }; + + // Invoker + struct Invoker : public device::BaseInvoker + { + using Argument = ReferenceMoeGemm1BlockScaleSplitK::Argument; + + float Run(const Argument& arg) + { + const int full_n = arg.c_t_k_n_.mDesc.GetLengths()[2]; + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = arg.a_t_k_.mDesc.GetLengths()[1]; + AccDataType v_acc{0}; + + ComputeTypeA v_a{0}; + ComputeTypeB v_b{0}; + + const int t = arg.sorted_token_ids_(m) & 0xffffff; + const int topk_id = (arg.sorted_token_ids_(m) & 0xff000000) >> 24; + const int e = arg.expert_ids_(m / arg.sorted_tile_size_); + const int token_cnt = arg.a_t_k_.mDesc.GetLengths()[0]; + if(t < token_cnt) + { + for(int k = 0; k < K; ++k) + { + if constexpr(is_same_v) + { + uint8_t i4x2 = arg.a_t_k_(t, k).data; + uint8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2 >> 0) & 0xf; + else + i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_a = i4_to_f32_gfx9(i4); +#else + v_a = i4 - 8; +#endif + } + else + { + arg.a_element_op_(v_a, arg.a_t_k_(t, k)); + } + // same for B matrix + if constexpr(is_same_v) + { + uint8_t i4x2 = arg.b_e_n_k_(e, k, n).data; + uint8_t i4 = 0; + if(k % 2 == 1) + { + i4 = (i4x2 >> 0) & 0xf; + } + else + { + i4 = (i4x2 >> 4) & 0xf; + } +#if CK_USE_PK4_LAYOUT_SHUFFLE + v_b = i4_to_f32_gfx9(i4); +#else + v_b = i4 - 8; +#endif + } + else + { + arg.b_element_op_(v_b, arg.b_e_n_k_(e, k, n)); + } + + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } + CDataType v_c{0}; + + arg.c_element_op_(v_c, v_acc); + arg.c_t_k_n_(t, topk_id, n) = v_c; + } + }; + + const ck::index_t max_token_id = arg.max_token_id_(0); + make_ParallelTensorFunctor(f_mk_kn_mn, max_token_id, full_n)( + std::thread::hardware_concurrency()); + + return 0; + } + + float Run(const device::BaseArgument* p_arg, + const StreamConfig& /* stream_config */ = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg)); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + bool IsSupportedArgument(const device::BaseArgument*) override { return true; } + + static auto MakeArgument(const Tensor& 
sorted_token_ids,
+ const Tensor& expert_ids,
+ const Tensor& max_token_id,
+ const index_t sorted_tile_size,
+ const Tensor& a_t_k,
+ const Tensor& b_e_n_k,
+ Tensor& c_t_k_n,
+ AElementwiseOperation a_element_op,
+ BElementwiseOperation b_element_op,
+ CElementwiseOperation c_element_op)
+ {
+ return Argument{sorted_token_ids,
+ expert_ids,
+ max_token_id,
+ sorted_tile_size,
+ a_t_k,
+ b_e_n_k,
+ c_t_k_n,
+ a_element_op,
+ b_element_op,
+ c_element_op};
+ }
+
+ static auto MakeInvoker() { return Invoker{}; }
+
+ virtual std::unique_ptr MakeInvokerPointer()
+ {
+ return std::make_unique(Invoker{});
+ }
+
+ std::string GetTypeString() const override
+ {
+ auto str = std::stringstream();
+
+ // clang-format off
+ str << "ReferenceMoeGemm1BlockScaleSplitK"
+ << std::endl;
+ // clang-format on
+
+ return str.str();
+ }
+
+ static float i4_to_f32_gfx9(uint8_t i4)
+ {
+ static std::unordered_map u = {{0b1000, -0.5000f},
+ {0b1001, -0.4375f},
+ {0b1010, -0.3750f},
+ {0b1011, -0.3125f},
+ {0b1100, -0.2500f},
+ {0b1101, -0.1875f},
+ {0b1110, -0.1250f},
+ {0b1111, -0.0625f},
+ {0b0, +0.0000f},
+ {0b1, +0.0625f},
+ {0b10, +0.1250f},
+ {0b11, +0.1875f},
+ {0b100, +0.2500f},
+ {0b101, +0.3125f},
+ {0b110, +0.3750f},
+ {0b111, +0.4375f}};
+
+ return u[i4];
+ }
+};
+
+} // namespace host
+} // namespace tensor_operation
+} // namespace ck
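
A note on the i4_to_f32_gfx9 helper in the reference above: the sixteen table entries are exactly the two's-complement value of the 4-bit nibble divided by 16. The standalone sketch below is not part of the patch (the name i4_to_f32_closed_form is made up for illustration); it only spot-checks that equivalence.

    #include <cassert>
    #include <cstdint>

    // Closed-form equivalent of the lookup table: sign-extend the nibble,
    // then scale by 1/16.
    static float i4_to_f32_closed_form(uint8_t i4)
    {
        int v = i4 & 0xF;   // keep the low nibble
        if(v & 0x8)         // sign bit of the 4-bit value
            v -= 16;        // 0b1000..0b1111 map to -8..-1
        return static_cast<float>(v) / 16.0f;
    }

    int main()
    {
        assert(i4_to_f32_closed_form(0b1000) == -0.5000f); // same as the table entry
        assert(i4_to_f32_closed_form(0b1111) == -0.0625f);
        assert(i4_to_f32_closed_form(0b0000) == +0.0000f);
        assert(i4_to_f32_closed_form(0b0111) == +0.4375f);
        return 0;
    }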
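In the reference Run() above, each sorted_token_ids entry is decoded as a token index in the low 24 bits and a top-k slot in the high 8 bits ((id & 0xffffff) and (id & 0xff000000) >> 24). The helpers below are a minimal illustration of that bit layout only; pack_sorted_token_id and unpack_sorted_token_id are hypothetical names, not part of the library.

    #include <cassert>
    #include <cstdint>

    // Pack/unpack the 32-bit layout decoded by the reference kernel:
    // bits [23:0]  = token index
    // bits [31:24] = top-k slot of that token
    inline uint32_t pack_sorted_token_id(uint32_t token, uint32_t topk_slot)
    {
        return (topk_slot << 24) | (token & 0xffffffu);
    }

    inline void unpack_sorted_token_id(uint32_t packed, uint32_t& token, uint32_t& topk_slot)
    {
        token     = packed & 0xffffffu;           // mirrors `& 0xffffff` in Run()
        topk_slot = (packed & 0xff000000u) >> 24; // mirrors `>> 24` in Run()
    }

    int main()
    {
        uint32_t t = 0, k = 0;
        unpack_sorted_token_id(pack_sorted_token_id(12345u, 3u), t, k);
        assert(t == 12345u && k == 3u);
        return 0;
    }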
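The F8xMXF4 pipeline's operator() earlier in the patch unrolls the K loop by two: main_body_implx2 consumes a ping and a pong K tile per iteration, the hot-loop count is (num_loop - 1) / 2, and the tail finishes either two tiles (TailNumber::Even) or one (TailNumber::Odd). The host-side sketch below only illustrates that bookkeeping, under the assumption that the dispatcher selects Even for an even num_loop and Odd otherwise, which is what the tile accounting in operator() implies.

    #include <cassert>

    enum class TailNumber { Odd, Even };

    // Pairs of K tiles executed by the 2x-unrolled hot loop.
    inline int hot_loop_iterations(int num_loop) { return (num_loop - 1) / 2; }

    // Assumed tail selection: Even tail when num_loop is even, Odd otherwise.
    inline TailNumber tail_number(int num_loop)
    {
        return (num_loop % 2 == 0) ? TailNumber::Even : TailNumber::Odd;
    }

    int main()
    {
        for(int num_loop = 1; num_loop <= 16; ++num_loop)
        {
            const int hot  = hot_loop_iterations(num_loop);
            const int tail = (tail_number(num_loop) == TailNumber::Even) ? 2 : 1;
            assert(2 * hot + tail == num_loop); // every K tile is processed exactly once
        }
        return 0;
    }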