From cde0c4d22409cc61a10412c0840353eb52a3f75d Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Fri, 25 Jul 2025 03:49:58 +0200 Subject: [PATCH] Support b_scale: (#2350) - extend pipeline v1 and v3 - add instances - add tests - add example Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> [ROCm/composable_kernel commit: b01a27ff2203221a1180a984bc8efefd79d078fc] --- example/01_gemm/CMakeLists.txt | 2 + .../gemm_wmma_fp16_pk_i4_v3_b_scale.cpp | 367 ++++ .../blockwise_gemm_pipeline_wmmaops_base.hpp | 76 +- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 155 +- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 99 +- .../impl/device_gemm_wmma_cshuffle_v3.hpp | 231 +-- .../device_gemm_wmma_cshuffle_v3_b_scale.hpp | 302 ++++ .../device_gemm_wmma_cshuffle_v3_common.hpp | 265 +++ .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1551 ++--------------- ...gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp | 551 ++++++ .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 1420 +++++++++++++++ .../gpu/gemm_b_scale.hpp | 24 +- .../gpu/gemm_b_scale/CMakeLists.txt | 6 +- ..._gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp | 72 + ...4_f16_mk_nk_mn_mem_v2_default_instance.cpp | 31 + ...e_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp | 4 +- .../profiler/profile_gemm_b_scale_impl.hpp | 4 +- profiler/src/CMakeLists.txt | 4 +- test/CMakeLists.txt | 1 + test/gemm_b_scale/CMakeLists.txt | 9 + .../test_gemm_b_scale_ut_cases.inc | 43 + test/gemm_b_scale/test_gemm_b_scale_util.hpp | 97 ++ test/gemm_b_scale/test_gemm_b_scale_wmma.cpp | 45 + test/gemm_b_scale/test_gemm_b_scale_xdl.cpp | 45 + 24 files changed, 3744 insertions(+), 1660 deletions(-) create mode 100644 example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp create mode 100644 include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp create mode 100644 test/gemm_b_scale/CMakeLists.txt create mode 100644 test/gemm_b_scale/test_gemm_b_scale_ut_cases.inc create mode 100644 test/gemm_b_scale/test_gemm_b_scale_util.hpp create mode 100644 test/gemm_b_scale/test_gemm_b_scale_wmma.cpp create mode 100644 test/gemm_b_scale/test_gemm_b_scale_xdl.cpp diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index e6a26ecafd..61f3ba5351 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -128,3 +128,5 @@ add_example_executable(example_gemm_wmma_fp16_pk_i4_v3 gemm_wmma_fp16_pk_i4_v3.c add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3) add_example_executable(example_gemm_wmma_fp16_fp8_v3 gemm_wmma_fp16_fp8_v3.cpp) add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_fp8_v3) +add_example_executable(example_gemm_wmma_fp16_pk_i4_v3_b_scale gemm_wmma_fp16_pk_i4_v3_b_scale.cpp) +add_example_dependencies(example_gemm_wmma example_gemm_wmma_fp16_pk_i4_v3_b_scale) diff --git 
a/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp new file mode 100644 index 0000000000..d3ac184019 --- /dev/null +++ b/example/01_gemm/gemm_wmma_fp16_pk_i4_v3_b_scale.cpp @@ -0,0 +1,367 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" + +using ADataType = ck::half_t; +using BDataType = ck::pk_i4_t; +using BScaleDataType = ck::half_t; +using AccDataType = float; +using CShuffleDataType = ck::half_t; +using CDataType = ck::half_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr bool PermuteA = false; +static constexpr bool PermuteB = true; + +static constexpr ck::index_t Scale_Block_N = 1; +static constexpr ck::index_t Scale_Block_K = 128; + +static constexpr ck::index_t KPerBlock = 64; + +// clang-format off +using DeviceGemmV2Instance = + ck::tensor_operation::device::DeviceGemm_BScale_Wmma_CShuffleV3< + ALayout, BLayout, CLayout, + ADataType, BDataType, BScaleDataType, CDataType, AccDataType, CShuffleDataType, + AElementOp, BElementOp, CElementOp, GemmDefault, + 256, Scale_Block_N, Scale_Block_K, + 128, 128, + KPerBlock, 8, 8, + 16, 16, + 4, 2, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + S<2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 8, 8, 0, + 1, 1, S<1, 32, 1, 8>, 8, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, + CDataType, CDataType, PermuteA, PermuteB>; + +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; +template +bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) +{ + using namespace ck::literals; + + auto M = problem_size.M; + auto N = problem_size.N; + auto K = problem_size.K; + auto StrideA = problem_size.StrideA; + auto StrideB = problem_size.StrideB; + auto StrideC = problem_size.StrideC; + auto KBatch = problem_size.KBatch; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if constexpr(std::is_same_v) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, ck::index_t stride, auto layout) { + if(stride == -1) + { + // give a chance if stride is -1, return a default packed stride + if constexpr(std::is_same_v) + { + return static_cast(col); + } + else + { + return static_cast(row); + } + } + else + return static_cast(stride); + }; + + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); + StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); + StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + BLayout{})); + + switch(config.init_method) + { 
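+        // Tensors initialized below: a_m_k holds the fp16 activations, b_k_n the packed int4
+        // weights, and b1_k_n the fp16 B scales, one scale per Scale_Block_K x Scale_Block_N
+        // tile of B (128 x 1 here, i.e. one scale for every 128 consecutive K elements of a
+        // column of B).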
+ case 0: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 1: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 3: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + case 4: + a_m_k.GenerateTensorValue(GeneratorTensor_1{1}); + b_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 5: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{1}); + break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_3{0.5, 0.5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "c_m_n: " << c_m_n_host_result.mDesc << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_k_n_device_buf(sizeof(BDataType) * b_k_n_permute.mDesc.GetElementSpaceSize() / 2); + DeviceMem b1_scale_device_buf(sizeof(BScaleDataType) * b1_k_n.mDesc.GetElementSpaceSize()); + DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n_device_result.mDesc.GetElementSpaceSize()); + + // weight permute + if constexpr(PermuteB) + { + int K1 = KPerBlock; + int K0 = K / KPerBlock; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + b_k_n_permute(j * N * K1 + i * K1 + jj) = b_k_n(i * K + (j * K1 + jj)); + } + } + } + } + else + { + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j++) + { + b_k_n_permute(i * K + j) = b_k_n(i * K + j); + } + } + } + + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int input[8]; + + for(int k = 0; k < 4; k++) + { + int i4x2 = b_k_n_permute(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int hi = input[2]; + int lo = input[0]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 0, i) = i4x2; + } + + { + int hi = input[6]; + int lo = input[4]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 2, i) = i4x2; + } + + { + int hi = input[3]; + int lo = input[1]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 4, i) = i4x2; + } + + { + int hi = input[7]; + int lo = input[5]; + int i4x2 = (hi << 4) | lo; + + b_k_n_permute(j + 6, i) = i4x2; + } + } + } + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n_permute.mData.data()); + b1_scale_device_buf.ToDevice(b1_k_n.mData.data()); + DeviceMem workspace; + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = 
CElementOp{}; + + // do GEMM + auto gemm = DeviceGemmV2Instance{}; + auto invoker = gemm.MakeInvoker(); + float ave_time = 0; + + auto argument = + gemm.MakeArgument(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + StrideC, + Scale_Stride_BN, + static_cast(b1_scale_device_buf.GetDeviceBuffer()), + KBatch, + a_element_op, + b_element_op, + c_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl; + + return true; + } + + std::string device_name = ck::get_device_name(); + if(!(device_name.find("gfx11") != std::string::npos || + device_name.find("gfx12") != std::string::npos)) + { + std::cout << "This kernel support gfx1100 and gfx1200 only" << std::endl; + + return true; + } + + bool pass = true; + if(config.do_verification) + { + Tensor b_k_n_dequant({K, N}); + + float v_b = 0; + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + ck::pk_i4_t i4x2 = b_k_n(k, n).data; + int8_t i4 = 0; + if(k % 2 == 1) + i4 = (i4x2.data >> 0) & 0xf; + else + i4 = (i4x2.data >> 4) & 0xf; + i4 = i4 - 8; + v_b = ck::type_convert(i4); + + b_k_n_dequant(k, n) = + ck::type_convert(v_b) * + ck::type_convert(b1_k_n(k / Scale_Block_K, n / Scale_Block_N)); + } + } + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n_dequant, c_m_n_host_result, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0}); + c_m_n_device_buf.FromDevice(c_m_n_device_result.mData.data()); + + pass &= ck::utils::check_err(c_m_n_device_result, + c_m_n_host_result, + "Error: Incorrect results!", + get_rtol(), + get_atol()); + } + + if(config.time_kernel) + { + ave_time = + invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, 0, 20, 50, true, 50}); + + std::size_t flop = 2_uz * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + + sizeof(BDataType) * K * N / + (ck::is_same_v, ck::pk_i4_t> ? 
2 : 1) + + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + return pass; +} + +bool run_gemm_splitk_example(int argc, char* argv[]) +{ + ProblemSizeSplitK problem_size; + ExecutionConfig config; + + return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config); +} + +int main(int argc, char* argv[]) { return !run_gemm_splitk_example(argc, argv); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index 14856f210c..d46c5b737d 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -91,6 +91,78 @@ struct BlockwiseGemmWmmaops_pipeline_base true> c_thread_buf_; + struct Empty + { + __device__ Empty(){}; + template + __device__ void GlobalLoad(bool cond) + { + ignore = NBuffer; + ignore = cond; + } + }; + + template + struct BScale + { + __device__ BScale(GridDesc b_scale_grid_desc_, + ThreadCopy b_scale_thread_copy_, + GridBuffer b_scale_grid_buf_) + : b_scale_thread_copy(b_scale_thread_copy_), + b_scale_grid_desc(b_scale_grid_desc_), + b_scale_grid_buf(b_scale_grid_buf_){}; + + static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block; + + static constexpr auto b_scale_thread_desc = BScaleThreadDesc{}; + + static constexpr auto b_scale_thread_copy_step = + make_tuple(make_multi_index(NWaves * NPerWmma, 0), + make_multi_index(-NPerBlock, 0), + make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK)); + + template + __device__ void GlobalLoad(bool cond) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_scale_thread_copy.Run(b_scale_grid_desc, + b_scale_grid_buf, + b_scale_thread_desc, + make_tuple(n0, Number<0>{}), + b_scale_thread_bufs(Number{})); + + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<0>{})); + }); + + if(cond) + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<2>{})); + } + else + { + b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, + b_scale_thread_copy_step.At(Number<1>{})); + } + } + + ThreadCopy b_scale_thread_copy; + GridDesc b_scale_grid_desc; + GridBuffer b_scale_grid_buf; + StaticallyIndexedArray{}> b_scale_thread_bufs; + }; + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } __device__ static auto GetWaveIdx() @@ -285,7 +357,7 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, @@ -296,7 +368,7 @@ struct BlockwiseGemmWmmaops_pipeline_base ComputeTypeB, decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_), - Sequence, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, B_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index df82e155be..f25648efa6 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ 
b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -132,6 +132,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -172,7 +175,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); @@ -186,6 +192,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -195,20 +203,42 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -258,6 +288,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1((i + 2) % num_loop_per_scale == 0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -378,6 +409,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -421,7 +455,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1( a_thread_desc_.GetElementSpaceSize()); @@ -435,6 +472,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -445,30 +484,57 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0_offset) { static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, - I0, - I0, - I0, - I0, - I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0_inner, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, - I0, - I0, - I0, - I0, - I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0_inner, I0, I0, I0), - b_thread_buf); + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{}, + m0, + 
I0, + I0, + I0, + I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0_inner, I0, I0, I0), + a_thread_buf); + }); + if constexpr(ck::is_same::value == true) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, + n0, + I0, + I0, + I0, + I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0_inner, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{}, + n0, + I0, + I0, + I0, + I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs(I0)[Number< + n0 * BScaleStruct::num_scale_k_block + + (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}], + b_thread_desc_, + make_tuple(I0, n0, k0_inner, I0, I0, I0), + b_thread_buf); + }); + } }); __builtin_amdgcn_sched_barrier(0); @@ -564,6 +630,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1((i + 2) % num_loop_per_scale == 0); a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -613,7 +680,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, @@ -624,7 +691,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1, + Sequence, Sequence<0, 1, 2, 3, 4, 5>, 5, B_K1, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 5ceb8a6be4..8fed23d151 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -132,6 +132,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3 + __device__ inline void LocalLoad(ABlockBuffer& a_block_buf, + AThreadBuffer& a_thread_buf, + BBlockBuffer& b_block_buf, + BThreadBuffer& b_thread_buf, + BScaleStruct& b_scale_struct) const + { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run( + a_block_desc_k0_m0_m1_m2_k1, + make_tuple(Number{}, m0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0), + a_thread_buf); + }); + + if constexpr(ck::is_same_v) + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + else + { + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run( + b_block_desc_k0_n0_n1_n2_k1, + make_tuple(Number{}, n0, I0, I0, I0, I0), + b_block_buf, + b_scale_struct.b_scale_thread_bufs( + I0)[Number{}], + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0), + b_thread_buf); + }); + } + }); + } + template + typename CThreadBuffer, + typename BScaleStruct> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -283,7 +338,10 @@ struct BlockwiseGemmWmmaops_pipeline_v3( @@ -298,6 +356,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop_per_scale == 1); + // Local prefill 1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); @@ -314,20 +374,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { - a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - 
a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); - }); + + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); __builtin_amdgcn_sched_barrier(0); @@ -348,6 +396,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3((i + 2) % num_loop_per_scale == 0); + static_for<0, KRepeat, 1>{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { @@ -392,22 +442,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}([&](auto k0) { - a_thread_copy_.Run( - a_block_desc_k0_m0_m1_m2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - a_thread_buf); - b_thread_copy_.Run( - b_block_desc_k0_n0_n1_n2_k1, - make_tuple(Number{}, I0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, k0, I0, I0, I0), - b_thread_buf); - }); + LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp index 90afc467d4..a921962c67 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3.hpp @@ -16,6 +16,7 @@ #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { namespace tensor_operation { @@ -229,222 +230,28 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common; - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); - - auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize; - auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize; - - ck::utility::RotatingMemWrapper rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_); - } - else - { - if(arg.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; - } - else - { - return 1; - } - }(); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(arg.KBatch > 1) - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - } - else - { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - // TODO: properly implement this check - return true; - } + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) - { - return false; - } - - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.KBatch > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) - { - return false; - } - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg); + return DeviceGemmCommon::IsSupportedArgument(arg); } // polymorphic diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000..1a68b35f1f --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -0,0 +1,302 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
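This header declares DeviceGemm_BScale_Wmma_CShuffleV3, which extends the plain WMMA CShuffle V3 device op with two extra arguments: StrideScaleB and a pointer to the B-scale tensor. Numerically the op computes C = A * dequant(B), where B is packed int4 and every ScaleBlockN x ScaleBlockK tile of B shares one scale; the example added earlier in this patch builds the same dequantized B on the host (plus the pk_i4 nibble permutation required by the device data layout) and verifies against ReferenceGemm. A minimal host-side model of that dequantized GEMM is sketched below; it is illustrative only, and the function name, layouts, and SBK parameter are choices of this sketch rather than ck API.

#include <cstdint>
#include <vector>

// Host-side model of the B-scale GEMM numerics (illustrative, not ck code):
// A is M x K row-major, B is K x N int4 packed two-per-byte along K
// (even k in the high nibble, odd k in the low nibble, as in the example above),
// and b_scale holds one value per SBK x 1 tile of B.
inline void dequant_gemm_ref(const std::vector<float>& a,            // M x K, row-major
                             const std::vector<std::uint8_t>& b_pk,  // K x N, K-contiguous, 2 int4 per byte
                             const std::vector<float>& b_scale,      // ceil(K/SBK) x N, row-major
                             std::vector<float>& c,                  // M x N, row-major
                             int M, int N, int K, int SBK)
{
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                const std::uint8_t byte = b_pk[(n * K + k) / 2];
                // unpack the 4-bit value and shift it from [0, 15] to [-8, 7]
                const int i4 = ((k % 2 == 1) ? (byte & 0xf) : ((byte >> 4) & 0xf)) - 8;
                const float b = static_cast<float>(i4) * b_scale[(k / SBK) * N + n];
                acc += a[m * K + k] * b;
            }
            c[m * N + n] = acc;
        }
    }
}
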
+ +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale +{ + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + index_t GetKPerBlock() override { return KPerBlock; } + + bool GetPermuteB() override { return PermuteB; } + + static auto MakeArgument(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const BScaleDataType* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + return Argument{p_a, + p_b, + p_c, + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + p_b_scale, + KBatch, + a_element_op, + b_element_op, + c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr MakeArgumentPointer(const void* p_a, + const void* p_b, + void* p_c, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + index_t StrideC, + index_t StrideScaleB, + const void* p_b_scale, + index_t KBatch, + AElementwiseOperation a_element_op, + 
BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) override + { + return std::make_unique(static_cast(p_a), + static_cast(p_b), + static_cast(p_c), + M, + N, + K, + StrideA, + StrideB, + StrideC, + StrideScaleB, + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_Wmma_CShuffleV3_BScale" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemm_Wmma_CShuffleV3_Common +{ + + using Argument = typename GridwiseGemm::Argument; + + /// @brief Helper structure responsible for kernel invocation. + /// + /// @paragraph The `Invoker` class is responsible for preparation and invocation of actual GPU + /// kernel function. It usually determines the launched grid size prepares kernel + /// arguments as well as perform specific kernel configuration selection based on + /// runtime arguments. + /// + /// @note If appropriately configured it may measure kernel execution time. + /// + struct Invoker : public BaseInvoker + { + /// @brief This function issues GPU kernel execution. + /// @param arg The GPU kernel arguments. + /// @param stream_config The HIP stream configuration helper structure. + /// @return The kernel's average execution time (if time measurement is + /// enabled). + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + if(stream_config.log_level_ > 0) + { + arg.Print(); + GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); + } + + if(!GridwiseGemm::CheckValidity(arg)) + { + throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); + } + + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); + + float ave_time = 0; + + index_t k_grain = arg.KBatch * KPerBlock; + index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; + + const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + const auto Run = [&](const auto& kernel) { + if(stream_config.flush_cache) + { + Argument arg_ = arg; + + const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( + arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); + const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( + arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); + + auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * + sizeof(ADataType) / GridwiseGemm::APackedSize; + auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * + sizeof(BDataType) / GridwiseGemm::BPackedSize; + + ck::utility::RotatingMemWrapper rotating_mem( + arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck::utility::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(arg_.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid, + 0, + arg_.M * arg_.N * sizeof(CDataType), + stream_config.stream_id_)); + }; + + ave_time = ck::utility::launch_and_time_kernel_with_preprocess( + stream_config, + run_flush_cache, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + arg_); + } + else + { + if(arg.KBatch > 1) + HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid, + 0, + arg.M * arg.N * sizeof(CDataType), + stream_config.stream_id_)); + + ave_time = launch_and_time_kernel( + stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); + } + }; + + constexpr index_t minimum_occupancy = []() { + if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) + { + return 2; + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + return (MPerBlock * NPerBlock / BlockSize <= 128) ? 
2 : 1; + } + else + { + return 1; + } + }(); + + if(has_main_k_block_loop) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + else + { + // TODO: Implement + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.KBatch > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding || + GemmSpec == GemmSpecialization::KPadding)) + { + return false; + } + + return GridwiseGemm::CheckValidity(arg); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index f3354cd5dd..75f12d094e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -14,47 +14,10 @@ #include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" namespace ck { -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); - - GridwiseGemm::template Run( - karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, - karg.p_c_grid + 
splitk_batch_offset.c_reduce_offset, - p_shared, - karg); -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - /// @brief \"Universal\" GEMM kernel with SplitK support. /// /// @par Overview @@ -207,391 +170,143 @@ template struct GridwiseGemm_wmma_cshuffle_v3 + : GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB> { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - static constexpr auto I4 = Number<4>{}; - static constexpr auto I5 = Number<5>{}; - static constexpr auto I6 = Number<6>{}; - static constexpr auto I7 = Number<7>{}; + using Base = GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; - // K1 should be Number<...> - static constexpr auto AK0Number = Number{}; - static constexpr auto BK0Number = Number{}; - static constexpr auto AK1Number = Number{}; - static constexpr auto BK1Number = Number{}; + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; - static constexpr index_t KPack = math::max( - math::lcm(AK1Number, BK1Number), - WmmaSelector::selected_wmma - 
.k_per_wmma); + using Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + + using Base::APackedSize; + using Base::BPackedSize; + + using Base::CalculateAK0Padded; + using Base::CalculateBK0Padded; + using Base::CalculateKPadded; + using Base::CalculateKRead; + using Base::CalculateMBlock; + using Base::CalculateMPadded; + using Base::CalculateNBlock; + using Base::CalculateNPadded; + using Base::MakeAGridDescriptor_AK0_M_AK1; + using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeCGridDescriptor_M_N; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + + using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; using ThisThreadBlock = ThisThreadBlock; - static constexpr index_t APackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - static constexpr index_t BPackedSize = []() { - if constexpr(is_same_v, pk_i4_t>) - return 2; - else - return 1; - }(); - - __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) - { - return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); - } - - __host__ static auto CalculateMPadded(index_t M) - { - return math::integer_least_multiple(M, MPerBlock); - } - - __host__ static auto CalculateNPadded(index_t N) - { - return math::integer_least_multiple(N, NPerBlock); - } - - __host__ static auto CalculateKPadded(index_t K) - { - return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; - } - - __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / AK1Value); - } - - __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); - } - - __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) - { - auto K_t = K_Batch * KPerBlock; - return (K + K_t - 1) / K_t * KPerBlock; - } - - __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) - { - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = K_Batch * KReadVec; - return (K + K_t - 1) / K_t * KReadVec; - } - - __host__ static auto CalculateMBlock(index_t M) - { - return math::integer_divide_ceil(M, MPerBlock); - } - - __host__ static auto CalculateNBlock(index_t N) - { - return math::integer_divide_ceil(N, NPerBlock); - } - - template - __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) - { - // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 - constexpr auto K0 = BlockDesc{}.GetLength(I0); - constexpr auto K1 = BlockDesc{}.GetLength(I2); -#ifdef __gfx12__ - constexpr auto KRow = I2; -#else - constexpr auto KRow = I1; -#endif - return transform_tensor_descriptor( - BlockDesc{}, - make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); - } - - __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( - index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) - { - const auto a_grid_desc_mraw_kraw = [&]() { - if constexpr(is_same_v) - { - return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); - } - else if constexpr(is_same_v) - { - return 
make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both M and K - const auto a_grid_desc_m_k = - transform_tensor_descriptor(a_grid_desc_mraw_kraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(MPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad M, but not K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_right_pad_transform(M, MPad - M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad K, but not M - const auto a_grid_desc_m_k = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_m_k, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - else - { - static_assert(!PermuteA, "PermuteA is not supported"); - - // not pad M or K - const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), - make_pass_through_transform(M)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return a_grid_desc_ak0_m_ak1; - } - } - - __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( - index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) - { - const auto b_grid_desc_nraw_kraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); - } - }(); - - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - static_assert(!(is_same_v, pk_i4_t> && - GemmSpec != GemmSpecialization::Default), - "pk_i4_t does not support padding"); - - if constexpr(GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad both N and K - const auto b_grid_desc_n_k = - transform_tensor_descriptor(b_grid_desc_nraw_kraw, - make_tuple(make_right_pad_transform(N, NPad - N), - make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto 
b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(NPad)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::MNPadding) - { - // pad N, but not K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else if constexpr(GemmSpec == GemmSpecialization::KPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad K, but not N - const auto b_grid_desc_n_k = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_n_k, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - if constexpr(!PermuteB) - { - // not pad N or K - const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), - make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); - - return b_grid_desc_bk0_n_bk1; - } - else - { - // Pre-shuffled Weight - // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] - constexpr index_t BK01 = KPerBlock / BK1Value; - const index_t BK0_ = StrideB / BK1Value; - const index_t BK00 = BK0_ / BK01; - - const auto b_grid_desc_bk00_n_bk01_bk1_permute = - make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); - - const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( - b_grid_desc_bk00_n_bk01_bk1_permute, - make_tuple(make_merge_transform(make_tuple(BK00, BK01)), - make_pass_through_transform(make_tuple(N)), - make_pass_through_transform(BK1Value)), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_grid_desc_bk0_n_bk1_permute; - } - } - } - - template - __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) - { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); - - return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); - } - - template - __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) - { - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - - return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); - } - - __host__ __device__ static auto - MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) - { - const auto c_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, 
StrideC)); - } - }(); - - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - // TODO: Investigate why this path is not used in the original - // gridwise_gemm_xdl_cshuffle_v3.hpp -#if 0 - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - - if constexpr(GemmSpec == GemmSpecialization::MNPadding || - GemmSpec == GemmSpecialization::MNKPadding) - { - // pad M and N - return transform_tensor_descriptor(c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), - make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::MPadding || - GemmSpec == GemmSpecialization::MKPadding) - { - // pad M, but not N - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else if constexpr(GemmSpec == GemmSpecialization::NPadding || - GemmSpec == GemmSpecialization::NKPadding) - { - // pad N, but not M - return transform_tensor_descriptor( - c_grid_desc_mraw_nraw, - make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0>{}, Sequence<1>{})); - } - else - { - // not pad M or N - return c_grid_desc_mraw_nraw; - } -#endif - } + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; struct Problem { @@ -749,943 +464,14 @@ struct GridwiseGemm_wmma_cshuffle_v3 index_t c_reduce_offset; }; - __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() - { - // A matrix in LDS memory, dst of blockwise copy - if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{} * AK1Number, AK1Number, I1)); - } - // xor tensor transformation request more unnecessary vgpr usage, would cause register spill - // in some cases. - else if constexpr(is_same::value) - { - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; - constexpr auto MLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); - constexpr auto M1 = MPerBlock / M0; - - constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); - constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / MPerWmma; - constexpr auto K0PerThreadRead = AK0Number / KThreadRead; - - constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) - ? 1 - : 128 / (AK1Number * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=n0 - constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) - ? 1 - : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 - ? 
M0 - : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - AK1Number)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; - } - } - - __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() - { - // B matrix in LDS memory, dst of blockwise copy - if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - // bank conflict when writting the data into LDS, but don't worry, we have whole entire - // loop to hide it in v4. it may give you some benefit from less valu in compute address - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{} * BK1Number, BK1Number, I1)); - } - else if constexpr(is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; - constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - BK0Number * Number{}, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<1, 0>{}, Sequence<2>{}), - make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - - constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_bk0_nldslayer_n_bk1, - make_tuple(make_pass_through_transform(BK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - else // RowMajor B - { - constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); - constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; - constexpr auto KThreadRead = 64 / NPerWmma; - constexpr auto K0PerThreadRead = BK0Number / KThreadRead; - - constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) - ? 1 - : 128 / (BK1Number * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=n0 - constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) - ? 1 - : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 - ? 
N0 - : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, - Number{}, - Number{}, - Number{}, - Number{}, - BK1Number)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_xor_with_modulo_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), - make_tuple( - Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(Number{}), - make_pass_through_transform(Number{}), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple(Number{}, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0>{}, - Sequence<1>{}, - Sequence<2>{}, - Sequence<3>{}, - Sequence<4>{}, - Sequence<5>{}), - make_tuple(Sequence<1>{}, - Sequence<2>{}, - Sequence<0, 3>{}, - Sequence<4, 5>{}, - Sequence<6>{}, - Sequence<7>{})); - - constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(Number{}, - Number{}, - Number{}, - Number{})), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(BK1Number)), - make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return b_lds_block_desc_bk0_n_bk1; - } - } - - __host__ __device__ static constexpr auto - // *Caution Here repeat is shuffle repeat - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() - { - constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - make_naive_tensor_descriptor_packed( - make_tuple(I1, - Number{}, - I1, - Number{})); - - return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; - } - - using BlockwiseGemmPipe = remove_cvref_t< - decltype(BlockGemmPipeline_Selector< - BlkGemmPipelineVer, - BlkGemmPipeSched, - BlockSize, - ADataType, - BDataType, - ComputeTypeA, - ComputeTypeB, - AccDataType, - decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), - decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), - ABlockTransferSrcScalarPerVector, - BBlockTransferSrcScalarPerVector, - MPerBlock, - NPerBlock, - KPerBlock, - MPerWmma, - NPerWmma, - MRepeat, - NRepeat, - KPack>())>; - - __device__ static constexpr index_t GetSharedMemoryNumberOfByte() - { - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - 
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - constexpr auto b_block_space_size_aligned = math::integer_least_multiple( - b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); - - // LDS allocation for C shuffle in LDS - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - constexpr auto c_block_size = - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize(); - - return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + - b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), - c_block_size * sizeof(CShuffleDataType)); - } - - // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} - __host__ static constexpr bool CheckValidity(const Argument& karg) - { - static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && - (NPerBlock % (NPerWmma * NRepeat)) == 0, - "Invalid tuning param!"); - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - !(is_same::value)) - { - if(!(karg.M % MPerBlock == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && - (is_same::value)) - { - if(!(karg.N % NPerBlock == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || - GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) - { - - auto K_t = karg.KBatch * KPerBlock; - if(!(karg.K % K_t == 0)) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); - auto K_t = karg.KBatch * KReadVec; - auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; - if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) - { - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.K % ABlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - if(karg.M % ABlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" - << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.N % BBlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - else - { - if(karg.K % BBlockTransferSrcScalarPerVector != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg K (" << karg.K - << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" - << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" - << __LINE__ << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same::value) - { - if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg N (" << karg.N - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - else - { - if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Arg M (" << karg.M - << ") value is not a multiple of " - "CShuffleBlockTransferScalarPerVector_NPerBlock (" - << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! 
" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - return false; - } - } - - if constexpr(!(is_same, half_t>::value || - is_same, float>::value || - is_same, bhalf_t>::value || - is_same, int32_t>::value)) - { - if(!karg.IsReduceAdd()) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" - << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } - if(karg.KBatch > 1) - { - return false; - } - } - } - - // check gridwise gemm pipeline - const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); - - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) - { - if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) - { - return false; - } - } - - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return true; - } - - __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) - { - const index_t num_loop = K / KPerBlock; - - return BlockwiseGemmPipe::BlockHasHotloop(num_loop); - } - - __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) - { - const index_t num_loop = K / KPerBlock; - - return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); - } - - template - __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) - { - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( - c_grid_desc_m_n, - make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), - make_unmerge_transform(make_tuple(NBlock, Number{}))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); - - return c_grid_desc_mblock_mperblock_nblock_nperblock; - } + using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; // return block_id to C matrix tile idx (m0, n0) mapping // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template - __device__ static void Run(const ADataType* p_a_grid, - const BDataType* p_b_grid, - CDataType* p_c_grid, - void* p_shared, - const Problem& problem, - const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) - { - const auto a_grid_buf = make_dynamic_buffer( - p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); - const auto b_grid_buf = make_dynamic_buffer( - p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); - auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; - const CElementwiseOperation c_element_op{}; - - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - - const auto block_work_idx = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - if(!block_2_ctile_map.ValidCTileIndex( - block_work_idx, - make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), - c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) - { - return; - } - - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = 
__builtin_amdgcn_readfirstlane(block_work_idx[I1]); - - // HACK: this force m/n_block_data_idx_on_grid into SGPR - const index_t m_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); - - const index_t n_block_data_idx_on_grid = - __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); - - // lds max alignment - constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); - - // A matrix in LDS memory, dst of blockwise copy - constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); - - // B matrix in LDS memory, dst of blockwise copy - constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - - // A matrix blockwise copy - auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // B matrix blockwise copy - auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( - b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); - - // LDS allocation for A and B: be careful of alignment - constexpr auto a_block_space_size_aligned = math::integer_least_multiple( - a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); - - // Cast after lds - auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); - - auto b_block_buf = make_dynamic_buffer( - reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * - sizeof(ADataType) / - APackedSize), - b_block_desc_bk0_n_bk1.GetElementSpaceSize()); - - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); - - // Blockwise GEMM pipeline - static_assert(std::is_default_constructible_v); - auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; - auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); - - const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); - - blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_buf, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_buf, 
- b_block_slice_copy_step, - c_thread_buf, - num_k_block_main_loop); - - // shuffle C and write out - { - // C mapping in single thread. - constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - blockwise_gemm_pipeline - .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - // C mapping in single block - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = - blockwise_gemm_pipeline - .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); - - constexpr auto MWave = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I1); - constexpr auto MSubGroup = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I2); - constexpr auto NWave = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I4); - constexpr auto NThreadPerSubGroup = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I5); - constexpr auto MAccVgprs = - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp - .GetLength(I6); - - // LDS descriptor, shuffle and write out in MRepeat x NRepeat times - constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = - GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); - - auto c_shuffle_block_buf = make_dynamic_buffer( - static_cast(p_shared), - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat - .GetElementSpaceSize()); - - constexpr auto - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = - transform_tensor_descriptor( - c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_tuple( - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // MRepeat per shuffle repeat - MWave, // MWave - MSubGroup, // MSubGroup * MAccVgprs = MPerWmma - MAccVgprs)), - make_freeze_transform(I0), - make_unmerge_transform(make_tuple( - Number{}, // NRepeat per shuffle repeat - NWave, // NWave - NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple(Sequence<>{}, - Sequence<0, 1, 2, 6>{}, - Sequence<>{}, - Sequence<3, 4, 5>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); - - const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; - const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - - const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( - MRepeat, MWave, MSubGroup, MAccVgprs))), - make_tuple(Sequence<0, 1, 2, 3>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor - .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( - NRepeat, NWave, NThreadPerSubGroup))), - make_tuple(Sequence<0, 1, 2>{}), - make_tuple(Sequence<0>{})); - - const auto 
n_thread_data_on_block_idx = - n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor - .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); - - // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< - AccDataType, - CShuffleDataType, - decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), - ck::tensor_operation::element_wise::PassThrough, - Sequence, - Sequence<0, 1, 2, 3, 4, 5, 6>, - 6, - 1, // vector write pixel - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - make_multi_index(0, - m_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - 0, - n_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3]), - ck::tensor_operation::element_wise::PassThrough{}}; - - // shuffle: blockwise copy C from LDS to global - auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< - ThisThreadBlock, // ThreadGroup - CElementwiseOperation, // ElementwiseOperation, - CGlobalMemoryDataOperation, // DstInMemOp, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerWmma, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, - CShuffleDataType, // typename SrcData, - CDataType, // typename DstData, - decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), - decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), - Sequence<0, 1, 2, 3>, // typename DimAccessOrder, - 3, // index_t VectorDim, - CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, - true, // bool ThreadTransferSrcResetCoordinateAfterRun, - false> // bool ThreadTransferDstResetCoordinateAfterRun> - {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - make_multi_index(0, 0, 0, 0), - c_grid_desc_mblock_mperblock_nblock_nperblock, - make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0), - c_element_op}; - - // space filling curve for local reg & global memory - // space filling curve for threadwise C in VGPR - constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6>, - Sequence>{}; - - // space filling curve for shuffled blockwise C in global mem - constexpr auto sfc_c_global = - SpaceFillingCurve, - Sequence<0, 2, 1, 3>, - Sequence<1, - CShuffleMRepeatPerShuffle * MWave * MPerWmma, - 1, - CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; - - constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); - - static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); - - static_for<0, num_access, 1>{}([&](auto access_id) { - // make sure it's safe to write to LDS - block_sync_lds(); - - // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run( - c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - sfc_c_vgpr.GetIndexTupleOfNumber(access_id), - c_thread_buf, - c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, - c_shuffle_block_buf); - - // make sure it's safe to read from LDS - block_sync_lds(); - - // each block copy its data from LDS to global - c_shuffle_block_copy_lds_to_global.Run( - 
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, - c_shuffle_block_buf, - c_grid_desc_mblock_mperblock_nblock_nperblock, - c_grid_buf); - - if constexpr(access_id < num_access - 1) - { - constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); - - // move on C - c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( - c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); - } - }); - } - } + __device__ static index_t GetKBlockPerScale() { return 1; } template (p_a_grid, - p_b_grid, - p_c_grid, - p_shared, - problem, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock); + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct (Empty) + using BScale = typename BlockwiseGemmPipe::Empty; + auto b_scale_struct = BScale{}; + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct); + } + + // Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. + template + __device__ static void + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + { + Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp new file mode 100644 index 0000000000..7b6ad5ca3e --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -0,0 +1,551 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
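The new b_scale gridwise header that follows layers block-wise B quantization on top of the shared v3 gridwise GEMM: each ScaleBlockN x ScaleBlockK tile of B shares one BScaleType value, and the scale tensor is addressed as (ceil(N / ScaleBlockN), ceil(K / ScaleBlockK)) with leading stride StrideScaleB (see b_scale_grid_desc_bn_ak in Run() below). A minimal host-side sketch of that addressing convention, with assumed helper names, is:

#include <cstddef>

// Illustrative helper (name and signature assumed, not defined by this header):
// offset of the block scale that applies to element (n, k) of B, following the
// (ceil(N / ScaleBlockN), ceil(K / ScaleBlockK)) layout with leading stride
// StrideScaleB used by b_scale_grid_desc_bn_ak.
inline std::size_t b_scale_offset(std::size_t n, std::size_t k,
                                  std::size_t StrideScaleB,
                                  std::size_t ScaleBlockN, std::size_t ScaleBlockK)
{
    return (n / ScaleBlockN) * StrideScaleB + (k / ScaleBlockK);
}
// A host reference would then dequantize as:
//   b_fp32(n, k) = float(b_quant(n, k)) * float(p_b_scale[b_scale_offset(n, k, ...)]);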
+ +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { + +template +struct GridwiseGemm_wmma_cshuffle_v3_b_scale + : GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB> +{ + using BScaleType = ck::half_t; + + using Base = GridwiseGemm_wmma_cshuffle_v3_base< + ALayout, + BLayout, + CLayout, + ADataType, + BDataType, + AccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1Value, + BK1Value, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + AThreadTransferSrcResetCoordinateAfterRun, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + BThreadTransferSrcResetCoordinateAfterRun, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; + using Base::I4; + using Base::I5; + using Base::I6; + using Base::I7; + + using 
Base::AK0Number; + using Base::AK1Number; + using Base::BK0Number; + using Base::BK1Number; + + using Base::APackedSize; + using Base::BPackedSize; + + using Base::CalculateAK0Padded; + using Base::CalculateBK0Padded; + using Base::CalculateKPadded; + using Base::CalculateKRead; + using Base::CalculateMBlock; + using Base::CalculateMPadded; + using Base::CalculateNBlock; + using Base::CalculateNPadded; + using Base::MakeAGridDescriptor_AK0_M_AK1; + using Base::MakeBGridDescriptor_BK0_N_BK1; + using Base::MakeCGridDescriptor_M_N; + + using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat; + + using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; + + using ThisThreadBlock = ThisThreadBlock; + + using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1; + using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1; + + struct Problem + { + __host__ Problem(index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + index_t KBatch_) + : M{M_}, + N{N_}, + K{K_}, + StrideA{StrideA_}, + StrideB{StrideB_}, + StrideC{StrideC_}, + StrideScaleB{StrideScaleB_}, + KBatch{KBatch_}, + MPadded{CalculateMPadded(M_)}, + NPadded{CalculateNPadded(N_)}, + KRead{CalculateKRead(K_, KBatch_)}, + KPadded{CalculateKPadded(K_, KBatch_)}, + AK0{CalculateAK0Padded(K_, KBatch_)}, + BK0{CalculateBK0Padded(K_, KBatch_)}, + MBlock{CalculateMBlock(M_)}, + NBlock{CalculateNBlock(N_)} + { + } + + __host__ void Print() const + { + std::cout << "problem {" + << "M:" << M << ", " + << "N:" << N << ", " + << "K:" << K << ", " + << "SA:" << StrideA << ", " + << "SB:" << StrideB << ", " + << "SC:" << StrideC << ", " + << "SScaleB:" << StrideScaleB << ", " + << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " + << "KRead:" << KRead << ", " + << "KP:" << KPadded << ", " + << "AK0:" << AK0 << ", " + << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " + << "NBlock: " << NBlock << "}" << std::endl; + } + + index_t M; + index_t N; + index_t K; + index_t StrideA; + index_t StrideB; + index_t StrideC; + index_t StrideScaleB; + index_t KBatch; + index_t MPadded; + index_t NPadded; + index_t KRead; + index_t KPadded; + index_t AK0; + index_t BK0; + index_t MBlock; + index_t NBlock; + }; + + // Argument + struct Argument : public tensor_operation::device::BaseArgument, public Problem + { + __host__ Argument(const ADataType* p_a_grid_, + const BDataType* p_b_grid_, + CDataType* p_c_grid_, + index_t M_, + index_t N_, + index_t K_, + index_t StrideA_, + index_t StrideB_, + index_t StrideC_, + index_t StrideScaleB_, + const BScaleType* p_b_scale_grid_, + index_t k_batch_, + AElementwiseOperation a_element_op_, + BElementwiseOperation b_element_op_, + CElementwiseOperation c_element_op_, + bool is_reduce_ = false) + : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_}, + p_a_grid{p_a_grid_}, + p_b_grid{p_b_grid_}, + p_c_grid{p_c_grid_}, + p_b_scale_grid{p_b_scale_grid_}, + a_element_op{a_element_op_}, + b_element_op{b_element_op_}, + c_element_op{c_element_op_}, + is_reduce(is_reduce_) + { + } + + __host__ __device__ inline bool IsReduceAdd() const + { + return (Problem::KBatch > 1) && is_reduce; + } + + __host__ __device__ inline bool IsAtomicAdd() const + { + return (Problem::KBatch > 1) && (!is_reduce); + } + + const ADataType* p_a_grid; + const BDataType* p_b_grid; + CDataType* p_c_grid; + + const BScaleType* p_b_scale_grid; + const AElementwiseOperation a_element_op; + const BElementwiseOperation 
b_element_op; + const CElementwiseOperation c_element_op; + bool is_reduce; + }; + + struct SplitKBatchOffset + { + + __device__ SplitKBatchOffset(Argument& karg) + { + if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead / APackedSize; + } + else if constexpr(is_same_v) + { + a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; + } + + if constexpr(is_same_v) + { + b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; + } + else if constexpr(is_same_v) + { + if constexpr(!PermuteB) + { + b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; + } + else + { + const int k0_offset = karg.KRead * karg.N; + b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; + } + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB; + } + else if constexpr(is_same_v) + { + scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + + if(blockIdx.z < static_cast(karg.KBatch - 1)) + { + karg.K = karg.KRead; + } + else + { + karg.K = karg.K - karg.KRead * (karg.KBatch - 1); + } + + if(karg.IsReduceAdd()) + { + c_reduce_offset = blockIdx.z * karg.M * karg.N; + } + else + { + c_reduce_offset = 0; + } + } + + index_t a_k_split_offset; + index_t b_k_split_offset; + index_t scale_k_split_offset; // New member for scale matrix offset + index_t c_reduce_offset; + }; + + using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + template + __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, + const BScaleType* p_b_scale_grid, + index_t block_n_id) + { + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + static constexpr auto wmma = + WmmaSelector{}; + static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma; + + static constexpr auto ScaleSliceSizeN = NRepeat; + static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma + + (get_thread_local_1d_id() / 32) % NWaves * NPerWmma; + auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + 1, + ScaleSliceSizeK, + 1, + false>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, + b_thread_offset_k / ScaleBlockK)); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + using BScale = + typename BlockwiseGemmPipe::template BScale; + + return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + } + + __device__ static index_t GetKBlockPerScale() + { + return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + const BScaleType* p_b_scale_grid, + void* p_shared, + const Problem& problem) + { + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + problem.M, problem.MPadded, 
problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1( + problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + + // B Scale grid + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( + make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), + math::integer_divide_ceil(problem.K, ScaleBlockK)), + make_tuple(problem.StrideScaleB, 1)); + + // divide block work by [M, N] + const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; + + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + if(!block_2_ctile_map.ValidCTileIndex( + block_work_idx, + make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0), + c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2)))) + { + return; + } + + const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); + const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + + // BScale struct + auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id); + + const index_t num_k_block_per_scale = GetKBlockPerScale(); + + Base::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + c_grid_desc_mblock_mperblock_nblock_nperblock, + block_m_id, + block_n_id, + num_k_block_per_scale, + b_scale_struct); + } + + // NOTE: Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. + template + __device__ static void + Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg) + { + Run( + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_c_grid + splitk_batch_offset.c_reduce_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_shared, + karg); + } +}; + +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp new file mode 100644 index 0000000000..5a4a41e507 --- /dev/null +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -0,0 +1,1420 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
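The common header that follows collects the pieces shared by the universal and b_scale WMMA v3 kernels: the __global__ entry point, the padded-size and split-K helpers, and the LDS/WMMA tile descriptors. For orientation, the split-K partitioning implemented by CalculateKRead (defined further down) and consumed by each gridwise's SplitKBatchOffset can be sketched on the host as below; the free-standing function name is assumed for illustration only.

#include <numeric> // std::lcm

// Illustrative sketch of how K is split across KBatch groups: each of the first
// KBatch - 1 batches reads KRead elements of K (rounded to a multiple of
// lcm(AK1, BK1)); the last batch takes the remainder.
inline int k_for_batch(int K, int KBatch, int AK1, int BK1, int batch_id)
{
    const int KReadVec = std::lcm(AK1, BK1);
    const int K_t      = KBatch * KReadVec;
    const int KRead    = (K + K_t - 1) / K_t * KReadVec;
    return (batch_id < KBatch - 1) ? KRead : K - KRead * (KBatch - 1);
}
// Example: K = 4096, KBatch = 3, AK1 = BK1 = 8  ->  KRead = 1368, last batch reads 1360.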
+ +#pragma once + +#include "ck/utility/env.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + +template +struct GridwiseGemm_wmma_cshuffle_v3_base +{ + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto AK0Number = Number{}; + static constexpr auto BK0Number = Number{}; + static constexpr auto AK1Number = Number{}; + static constexpr auto BK1Number = Number{}; + + static constexpr index_t KPack = math::max( + math::lcm(AK1Number, BK1Number), + WmmaSelector::selected_wmma + .k_per_wmma); + + using ThisThreadBlock = ThisThreadBlock; + + static constexpr index_t APackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + static constexpr index_t BPackedSize = []() { + if constexpr(is_same_v, pk_i4_t>) + return 2; + else + return 1; + }(); + + // return block_id to C matrix tile idx (m0, n0) mapping + // if arch = gfx942 + using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; + // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; + + __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch) + { + return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch); + } + + __host__ static auto CalculateMPadded(index_t M) + { + return math::integer_least_multiple(M, MPerBlock); + } + + __host__ static auto CalculateNPadded(index_t N) + { + return math::integer_least_multiple(N, NPerBlock); + } + + __host__ static auto CalculateKPadded(index_t K) + { + return math::integer_divide_ceil(K, KPerBlock) * KPerBlock; + } + + __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / 
AK1Value); + } + + __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * (KPerBlock / BK1Value); + } + + __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1) + { + auto K_t = K_Batch * KPerBlock; + return (K + K_t - 1) / K_t * KPerBlock; + } + + __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1) + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = K_Batch * KReadVec; + return (K + K_t - 1) / K_t * KReadVec; + } + + __host__ static auto CalculateMBlock(index_t M) + { + return math::integer_divide_ceil(M, MPerBlock); + } + + __host__ static auto CalculateNBlock(index_t N) + { + return math::integer_divide_ceil(N, NPerBlock); + } + + template + __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&) + { + // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1 + constexpr auto K0 = BlockDesc{}.GetLength(I0); + constexpr auto K1 = BlockDesc{}.GetLength(I2); +#ifdef __gfx12__ + constexpr auto KRow = I2; +#else + constexpr auto KRow = I1; +#endif + return transform_tensor_descriptor( + BlockDesc{}, + make_tuple(make_unmerge_transform(make_tuple(Number{}, KRow)), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{})); + } + + __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1( + index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0) + { + const auto a_grid_desc_mraw_kraw = [&]() { + if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1)); + } + else if constexpr(is_same_v) + { + return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both M and K + const auto a_grid_desc_m_k = + transform_tensor_descriptor(a_grid_desc_mraw_kraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(MPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad M, but not K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_right_pad_transform(M, MPad - M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad K, but not M + const auto a_grid_desc_m_k = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + 
make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_m_k, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + else + { + static_assert(!PermuteA, "PermuteA is not supported"); + + // not pad M or K + const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_grid_desc_mraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_pass_through_transform(M)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return a_grid_desc_ak0_m_ak1; + } + } + + __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1( + index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0) + { + const auto b_grid_desc_nraw_kraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); + } + }(); + + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + static_assert(!(is_same_v, pk_i4_t> && + GemmSpec != GemmSpecialization::Default), + "pk_i4_t does not support padding"); + + if constexpr(GemmSpec == GemmSpecialization::NKPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad both N and K + const auto b_grid_desc_n_k = + transform_tensor_descriptor(b_grid_desc_nraw_kraw, + make_tuple(make_right_pad_transform(N, NPad - N), + make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(NPad)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::MNPadding) + { + // pad N, but not K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else if constexpr(GemmSpec == GemmSpecialization::KPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad K, but not N + const auto b_grid_desc_n_k = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_n_k, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + if constexpr(!PermuteB) + { + // not pad N or K + const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_grid_desc_nraw_kraw, + make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), + 
make_pass_through_transform(N)), + make_tuple(Sequence<1>{}, Sequence<0>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return b_grid_desc_bk0_n_bk1; + } + else + { + // Pre-shuffled Weight + // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1] + constexpr index_t BK01 = KPerBlock / BK1Value; + const index_t BK0_ = StrideB / BK1Value; + const index_t BK00 = BK0_ / BK01; + + const auto b_grid_desc_bk00_n_bk01_bk1_permute = + make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value)); + + const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor( + b_grid_desc_bk00_n_bk01_bk1_permute, + make_tuple(make_merge_transform(make_tuple(BK00, BK01)), + make_pass_through_transform(make_tuple(N)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc_bk0_n_bk1_permute; + } + } + } + + template + __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&) + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + + return MakeWmmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + } + + template + __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&) + { + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + return MakeWmmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + } + + __host__ __device__ static auto + MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC) + { + const auto c_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC)); + } + }(); + + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + // TODO: Investigate why this path is not used in the original + // gridwise_gemm_xdl_cshuffle_v3.hpp +#if 0 + using GemmSpecialization = tensor_operation::device::GemmSpecialization; + + if constexpr(GemmSpec == GemmSpecialization::MNPadding || + GemmSpec == GemmSpecialization::MNKPadding) + { + // pad M and N + return transform_tensor_descriptor(c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), + make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::MPadding || + GemmSpec == GemmSpecialization::MKPadding) + { + // pad M, but not N + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else if constexpr(GemmSpec == GemmSpecialization::NPadding || + GemmSpec == GemmSpecialization::NKPadding) + { + // pad N, but not M + return transform_tensor_descriptor( + c_grid_desc_mraw_nraw, + make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + } + else + { + // not pad M or N + return c_grid_desc_mraw_nraw; + } +#endif + } + + __device__ 
static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() + { + // A matrix in LDS memory, dst of blockwise copy + if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{} * AK1Number, AK1Number, I1)); + } + // xor tensor transformation request more unnecessary vgpr usage, would cause register spill + // in some cases. + else if constexpr(is_same::value) + { + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize; + constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize; + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + AK0Number * Number{}, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_ak0_mldslayer_m_ak1, + make_tuple(make_pass_through_transform(AK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + else // ColumnMajor A + { + // kfold and mpair dimension is not always required. + // more dimension in merge_transform increase the difficulty of generating immarg offset + // for compiler. + constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto M1 = MPerBlock / M0; + + constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / MPerWmma; + constexpr auto K0PerThreadRead = AK0Number / KThreadRead; + + constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128) + ? 1 + : 128 / (AK1Number * M0 * sizeof(ADataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=mpair<=n0 + constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128) + ? 1 + : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0 + ? 
M0 + : 128 / (AK1Number * MPerWmma * sizeof(ADataType))); + + constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + AK1Number)); + + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( + a_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(AK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_lds_block_desc_ak0_m_ak1; + } + } + + __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1() + { + // B matrix in LDS memory, dst of blockwise copy + if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) + { + // bank conflict when writting the data into LDS, but don't worry, we have whole entire + // loop to hide it in v4. it may give you some benefit from less valu in compute address + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{} * BK1Number, BK1Number, I1)); + } + else if constexpr(is_same::value) + { + // NLdsLayer * K0 as logical Bank + constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize; + constexpr index_t NLdsLayer = LdsSize < 1 ? 
1 : LdsSize; + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( + make_tuple( + BK0Number * Number{}, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple(make_xor_with_modulo_transform(make_tuple( + Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_bk0_nldslayer_n_bk1, + make_tuple(make_pass_through_transform(BK0Number), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + else // RowMajor B + { + constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1); + constexpr auto N1 = NPerBlock / N0; + + constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0); + constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite; + constexpr auto KThreadRead = 64 / NPerWmma; + constexpr auto K0PerThreadRead = BK0Number / KThreadRead; + + constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128) + ? 1 + : 128 / (BK1Number * N0 * sizeof(BDataType)); + constexpr auto KThreadReadPerm = + (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 + ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) + : KThreadRead; + + // 1<=npair<=n0 + constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128) + ? 1 + : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0 + ? 
N0 + : 128 / (BK1Number * NPerWmma * sizeof(BDataType))); + + constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + BK1Number)); + + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}), + make_tuple( + Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{})); + + constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple( + make_pass_through_transform(Number{}), + make_pass_through_transform(Number{}), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_unmerge_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{}), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<1>{}, + Sequence<2>{}, + Sequence<0, 3>{}, + Sequence<4, 5>{}, + Sequence<6>{}, + Sequence<7>{})); + + constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor( + b_lds_block_desc_unmerged, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(Number{}, + Number{}, + Number{}, + Number{})), + make_merge_transform_v3_division_mod( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(BK1Number)), + make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_lds_block_desc_bk0_n_bk1; + } + } + + __host__ __device__ static constexpr auto + // *Caution Here repeat is shuffle repeat + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat() + { + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + make_naive_tensor_descriptor_packed( + make_tuple(I1, + Number{}, + I1, + Number{})); + + return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat; + } + + using BlockwiseGemmPipe = remove_cvref_t< + decltype(BlockGemmPipeline_Selector< + BlkGemmPipelineVer, + BlkGemmPipeSched, + BlockSize, + ADataType, + BDataType, + ComputeTypeA, + ComputeTypeB, + AccDataType, + decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), + decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), + ABlockTransferSrcScalarPerVector, + BBlockTransferSrcScalarPerVector, + MPerBlock, + NPerBlock, + KPerBlock, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + KPack>())>; + + template + __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock) + { + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor( + c_grid_desc_m_n, + make_tuple(make_unmerge_transform(make_tuple(MBlock, Number{})), + make_unmerge_transform(make_tuple(NBlock, Number{}))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{})); + + return 
c_grid_desc_mblock_mperblock_nblock_nperblock; + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + template + __host__ static constexpr bool CheckValidity(const Argument& karg) + { + static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && + (NPerBlock % (NPerWmma * NRepeat)) == 0, + "Invalid tuning param!"); + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + !(is_same::value)) + { + if(!(karg.M % MPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding) && + (is_same::value)) + { + if(!(karg.N % NPerBlock == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding || + GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)) + { + + auto K_t = karg.KBatch * KPerBlock; + if(!(karg.K % K_t == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " + << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + constexpr auto KReadVec = math::lcm(AK1Number, BK1Number); + auto K_t = karg.KBatch * KReadVec; + auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; + if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) + { + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.K % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.M % ABlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of ABlockTransferSrcScalarPerVector (" + << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! 
" << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + else + { + if(karg.K % BBlockTransferSrcScalarPerVector != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K (" << karg.K + << ") value is not a multiple of BBlockTransferSrcScalarPerVector (" + << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + return false; + } + } + + if constexpr(is_same::value) + { + if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg N (" << karg.N + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + else + { + if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg M (" << karg.M + << ") value is not a multiple of " + "CShuffleBlockTransferScalarPerVector_NPerBlock (" + << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + + if constexpr(!(is_same, half_t>::value || + is_same, float>::value || + is_same, bhalf_t>::value || + is_same, int32_t>::value)) + { + if(!karg.IsReduceAdd()) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + if(karg.KBatch > 1) + { + return false; + } + } + } + + // check gridwise gemm pipeline + const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); + + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockHasHotloop(num_loop); + } + + __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K) + { + const index_t num_loop = K / KPerBlock; + + return BlockwiseGemmPipe::BlockLoopTailNum(num_loop); + } + + __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size_aligned = math::integer_least_multiple( + b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align); + + // LDS allocation for C shuffle in LDS + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + constexpr auto c_block_size = + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize(); + + return 
math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize + + b_block_space_size_aligned * sizeof(BDataType) / BPackedSize), + c_block_size * sizeof(CShuffleDataType)); + } + + template + __device__ static void Run(const ADataType* p_a_grid, + const BDataType* p_b_grid, + CDataType* p_c_grid, + void* p_shared, + const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, + const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, + const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t& block_m_id, + const index_t& block_n_id, + const index_t& num_k_block_per_scale, + BScaleStruct& b_scale_struct) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + const AElementwiseOperation a_element_op{}; + const BElementwiseOperation b_element_op{}; + const CElementwiseOperation c_element_op{}; + + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, m_block_data_idx_on_grid, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // B matrix blockwise copy + auto b_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true, + BlockwiseGemmPipe::GlobalBufferNum>( + b_grid_desc_bk0_n_bk1, + make_multi_index(0, n_block_data_idx_on_grid, 0), + b_element_op, + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size_aligned = math::integer_least_multiple( + a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align); + + // Cast after lds + auto a_block_buf = make_dynamic_buffer( + 
static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + + auto b_block_buf = make_dynamic_buffer( + reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * + sizeof(ADataType) / + APackedSize), + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_buf, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_buf, + b_block_slice_copy_step, + c_thread_buf, + b_scale_struct, + num_k_block_main_loop, + num_k_block_per_scale); + + // shuffle C and write out + { + // C mapping in single thread. + constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + blockwise_gemm_pipeline + .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + // C mapping in single block + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp = + blockwise_gemm_pipeline + .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs(); + + constexpr auto MWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I1); + constexpr auto MSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I2); + constexpr auto NWave = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I4); + constexpr auto NThreadPerSubGroup = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I5); + constexpr auto MAccVgprs = + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp + .GetLength(I6); + + // LDS descriptor, shuffle and write out in MRepeat x NRepeat times + constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat = + GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat + .GetElementSpaceSize()); + + constexpr auto + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs = + transform_tensor_descriptor( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // MRepeat per shuffle repeat + MWave, // MWave + MSubGroup, // MSubGroup * MAccVgprs = MPerWmma + MAccVgprs)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // NRepeat per shuffle repeat + NWave, // NWave + NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<>{}, + Sequence<0, 1, 2, 6>{}, + 
Sequence<>{}, + Sequence<3, 4, 5>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + MRepeat, MWave, MSubGroup, MAccVgprs))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor + .CalculateBottomIndex(make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor = + make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple( + NRepeat, NWave, NThreadPerSubGroup))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor + .CalculateBottomIndex(make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6>, + 6, + 1, // vector write pixel + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + make_multi_index(0, + m_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + 0, + n_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; + + // shuffle: blockwise copy C from LDS to global + auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< + ThisThreadBlock, // ThreadGroup + CElementwiseOperation, // ElementwiseOperation, + CGlobalMemoryDataOperation, // DstInMemOp, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + CShuffleDataType, // typename SrcData, + CDataType, // typename DstData, + decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat), + decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), + Sequence<0, 1, 2, 3>, // typename DimAccessOrder, + 3, // index_t VectorDim, + CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector, + true, // bool ThreadTransferSrcResetCoordinateAfterRun, + false> // bool ThreadTransferDstResetCoordinateAfterRun> + {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + make_multi_index(0, 0, 0, 0), + c_grid_desc_mblock_mperblock_nblock_nperblock, + make_multi_index(block_m_id, 0, block_n_id, 0), + c_element_op}; + + // space filling curve for local reg & global memory + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + 
Sequence<0, 1, 2, 3, 4, 5, 6>, + Sequence>{}; + + // space filling curve for shuffled blockwise C in global mem + constexpr auto sfc_c_global = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMRepeatPerShuffle * MWave * MPerWmma, + 1, + CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!"); + + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run( + c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + c_shuffle_block_copy_lds_to_global.Run( + c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat, + c_shuffle_block_buf, + c_grid_desc_mblock_mperblock_nblock_nperblock, + c_grid_buf); + + if constexpr(access_id < num_access - 1) + { + constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id); + + // move on C + c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow( + c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step); + } + }); + } + } +}; + +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp index 93eed31bc5..6543e3df23 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_b_scale.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
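// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the patch): the gridwise
// header above splits K with CalculateKPadded / CalculateBK0Padded before the
// split-K loop. A minimal host-side recomputation of those two quantities,
// assuming the tile sizes used by the example instance (KPerBlock = 64,
// AK1 = BK1 = 8); the problem size and KBatch below are made up for the demo.
#include <cstdio>

int main()
{
    const int KPerBlock = 64;
    const int BK1       = 8;

    const int K      = 1000; // deliberately not a multiple of KPerBlock
    const int KBatch = 2;    // split-K factor

    // KPadded: K handled by one split, rounded up to a multiple of KPerBlock
    const int K_t     = KBatch * KPerBlock;               // 128
    const int KPadded = (K + K_t - 1) / K_t * KPerBlock;  // ceil(1000/128) * 64 = 512

    // BK0Padded: number of BK1-wide K0 slots per split
    const int BK0Padded = (K + K_t - 1) / K_t * (KPerBlock / BK1); // 8 * 8 = 64

    std::printf("KPadded per split = %d, BK0 per split = %d\n", KPadded, BK0Padded);
    return 0;
}
// ---------------------------------------------------------------------------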
#pragma once @@ -17,6 +17,22 @@ namespace tensor_operation { namespace device { namespace instance { #if(defined(CK_ENABLE_FP16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_WMMA +void add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances); +#endif +#ifdef CK_USE_XDL void add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( std::vector>>& instances); #endif +#endif template && is_same_v && is_same_v) { +#ifdef CK_USE_WMMA + add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); +#endif +#ifdef CK_USE_XDL add_device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instances(op_ptrs); +#endif } } diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt index 424320fa8f..34f51f5f58 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/CMakeLists.txt @@ -1,10 +1,12 @@ -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_B_SCALE_INSTANCES) list(APPEND GEMM_B_SCALE_INSTANCES device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp + device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp ) set_source_files_properties(device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") +set_source_files_properties(device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") -add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES}) \ No newline at end of file +add_instance_library(device_gemm_b_scale_instance ${GEMM_B_SCALE_INSTANCES}) diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp new file mode 100644 index 0000000000..9476eb6bf0 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp @@ -0,0 +1,72 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using I4 = pk_i4_t; +using F16 = half_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave; + +template +using device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| CLayout|AData| BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| Compute| Compute| PermuteA| PermuteB| + //################################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock| ScalarPerVector| Pipeline| Pipeline| TypeA| TypeB| | | + //################################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock| Scheduler| Verision| | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //1 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, 
PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 8, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //2 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //3 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //4 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 64, 64, 64, 8, 8, 16, 16, 2, 2, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 4>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //5 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //6 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //7 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 32, 32, 64, 8, 8, 16, 16, 2, 2, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //8 + + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //9 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //10 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 8, 16, 16, 4, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 
8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //11 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //12 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //13 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 32, 32, 128, 8, 8, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 2, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //14 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //15 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //16 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 8, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //17 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //18 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Intrawave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false>, //19 + DeviceGemm_BScale_Wmma_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 32, 1, 128, 16, 16, 128, 8, 8, 16, 16, 1, 1, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 4, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, 1, 1, S<1, 16, 1, 2>, 4, Interwave, BlockGemmPipelineVersion::v1, half_t, half_t, false, false> //20 + + // clang-format on + >; +} // namespace instance 
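// ---------------------------------------------------------------------------
// Editor's note (illustrative sketch, not part of the patch): every instance
// in the tuple above uses ScaleBlockN = 1 and ScaleBlockK = 128, i.e. one fp16
// scale per column of B for every 128 elements along K. A minimal computation
// of the resulting scale-tensor extents, using the N/K sizes exercised by the
// new tests (values here are for illustration only).
#include <cstdio>

int main()
{
    const int N = 512, K = 1024;
    const int ScaleBlockN = 1, ScaleBlockK = 128;

    const int ScaleK = (K + ScaleBlockK - 1) / ScaleBlockK; // 8 scale rows along K
    const int ScaleN = (N + ScaleBlockN - 1) / ScaleBlockN; // 512 scale columns along N

    // The B-scale tensor therefore holds ScaleK * ScaleN fp16 values.
    std::printf("B-scale tensor: %d x %d = %d scales\n", ScaleK, ScaleN, ScaleK * ScaleN);
    return 0;
}
// ---------------------------------------------------------------------------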
+} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp new file mode 100644 index 0000000000..9c196a3c58 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_wmma_f16_i4_f16/device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instance.cpp @@ -0,0 +1,31 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { +void add_device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_v2_default_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, device_gemm_b_scale_wmma_f16_i4_f16_mk_nk_mn_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp index ce5cf21a85..1f8ca4d23a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_b_scale/device_gemm_b_scale_xdl_f16_i4_f16/device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" @@ -46,7 +46,7 @@ using device_gemm_b_scale_xdl_f16_i4_f16_mk_nk_mn_mem_instances = std::tuple< //#########################| | | | Type| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#########################| | | | | | Type| | | | Operation| Operation| Operation| | | N| K| | | | | |Wave| Wave| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - + //Compute friendly DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 8, 32, 32, 32, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3, half_t, half_t, false, false>, //0 DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F16, I4, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 8, 32, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, 1, 1, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v4, half_t, half_t, false, false>, //1 diff --git a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp index fe977e766e..86370e2f47 100644 --- a/profiler/include/profiler/profile_gemm_b_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_b_scale_impl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -173,7 +173,7 @@ bool profile_gemm_b_scale_impl(int do_verification, } } using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 256; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmBScale_MK_NK, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 768; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmBScale_MK_NK, Regular) +{ + std::vector Ms{512, 1024}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} diff --git a/test/gemm_b_scale/test_gemm_b_scale_util.hpp b/test/gemm_b_scale/test_gemm_b_scale_util.hpp new file mode 100644 index 0000000000..ec47470b84 --- /dev/null +++ b/test/gemm_b_scale/test_gemm_b_scale_util.hpp @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_b_scale_impl.hpp" + +namespace ck { +namespace test { + +template +class TestGemmBScale : public testing::Test +{ + using Row = ck::tensor_layout::gemm::RowMajor; + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using BScaleDataType = std::tuple_element_t<4, Tuple>; + using ComputeDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + + public: + static constexpr ck::index_t ScaleBlockK = 128; // all instances + static constexpr bool verify_ = true; + static constexpr int init_method_ = 2; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideC, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideC, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_gemm_b_scale_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_b_scale/test_gemm_b_scale_wmma.cpp b/test/gemm_b_scale/test_gemm_b_scale_wmma.cpp new file mode 100644 index 0000000000..38a3540925 --- /dev/null +++ b/test/gemm_b_scale/test_gemm_b_scale_wmma.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_b_scale_util.hpp" + +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmBScale_MK_NK + : public ck::test::TestGemmBScale, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType + std::tuple< F16, I4, F16, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmBScale_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_b_scale_ut_cases.inc" diff --git a/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp b/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp new file mode 100644 index 0000000000..38a3540925 --- /dev/null +++ b/test/gemm_b_scale/test_gemm_b_scale_xdl.cpp @@ -0,0 +1,45 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_b_scale_util.hpp" + +using I4 = ck::pk_i4_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmBScale_MK_NK + : public ck::test::TestGemmBScale, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< + // ADataType, BDataType, BScaleDataType, ComputeDataType, CDataType + std::tuple< F16, I4, F16, F16, F16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmBScale_MK_NK, KernelTypes_MK_NK); + +#include "test_gemm_b_scale_ut_cases.inc"
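// ---------------------------------------------------------------------------
// Editor's note (appended for illustration, not part of the patch): a naive
// host reference of the blockwise B-scale GEMM the new kernels and tests
// exercise. The dequantization rule assumed here -- each int4 weight is
// multiplied by the fp16 scale of its [ScaleBlockK x ScaleBlockN] block before
// accumulation -- is consistent with the scale-tensor shape used throughout
// the patch, but treat this as a sketch, not the library's reference path.
// fp32 is used instead of half/pk_i4_t purely to keep the sketch self-contained.
#include <cstddef>
#include <vector>

// A: MxK row-major; B: KxN stored column-major, one signed char per int4 value;
// b_scale: ceil(K/ScaleBlockK) x ceil(N/ScaleBlockN) row-major; C: MxN row-major.
void reference_b_scale_gemm(const std::vector<float>& a,
                            const std::vector<signed char>& b_int4,
                            const std::vector<float>& b_scale,
                            std::vector<float>& c,
                            std::size_t M, std::size_t N, std::size_t K,
                            std::size_t ScaleBlockN, std::size_t ScaleBlockK)
{
    const std::size_t scale_cols = (N + ScaleBlockN - 1) / ScaleBlockN;

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
            {
                // pick the scale of the [ScaleBlockK x ScaleBlockN] block containing (k, n)
                const float scale = b_scale[(k / ScaleBlockK) * scale_cols + n / ScaleBlockN];
                const float b     = static_cast<float>(b_int4[n * K + k]) * scale; // dequantize
                acc += a[m * K + k] * b;
            }
            c[m * N + n] = acc;
        }
}
// ---------------------------------------------------------------------------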