diff --git a/CHANGELOG.md b/CHANGELOG.md index 38669385f3..6dd06195c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added +* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data @@ -15,6 +16,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for GKCYX layout for grouped convolution backward data (NGCHW/GKCYX/NGKHW). * Added support for Stream-K version of mixed fp8/bf16 GEMM * Added support for Multiple D GEMM +* Added support for Multiple ABD GEMM * Added GEMM pipeline for microscaling (MX) FP8/FP6/FP4 data types * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. diff --git a/Jenkinsfile b/Jenkinsfile index 9d1af7c5d9..efe08a7d41 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -157,9 +157,9 @@ def getDockerImage(Map conf=[:]){ image = getDockerImageName() echo "Using default docker: ${image}" } - //Check if image exists + //Check if image exists def retimage - try + try { echo "Pulling image: ${image}" retimage = docker.image("${image}") @@ -232,7 +232,7 @@ def cmake_build(Map conf=[:]){ def setup_args = conf.get("setup_args","") // make sure all unit tests always run on develop branch def runAllUnitTests = (env.BRANCH_NAME == "develop") ? 
true : params.RUN_ALL_UNIT_TESTS - + if (prefixpath != "/usr/local"){ setup_args = setup_args + " -DCMAKE_PREFIX_PATH=${prefixpath} " } @@ -357,7 +357,7 @@ def cmake_build(Map conf=[:]){ "build_cmd", "${build_envs} ninja -j${nt} ${config_targets}" ) - + cmd = conf.get("cmd", """ ${setup_cmd} ${build_cmd} @@ -449,7 +449,7 @@ def buildHipClangJob(Map conf=[:]){ checkout scm def prefixpath = conf.get("prefixpath", "/opt/rocm") - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts if ( params.BUILD_INSTANCES_ONLY ){ dockerOpts = "--group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" @@ -515,7 +515,7 @@ def Build_CK(Map conf=[:]){ checkout scm def prefixpath = conf.get("prefixpath", "/opt/rocm") - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts="--device=/dev/kfd --device=/dev/dri --group-add video --group-add render --cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " @@ -719,7 +719,7 @@ def process_results(Map conf=[:]){ def image = "${env.CK_DOCKERHUB}:ck_ub22.04_rocm6.3" def prefixpath = "/opt/rocm" - // Jenkins is complaining about the render group + // Jenkins is complaining about the render group def dockerOpts="--cap-add=SYS_PTRACE --security-opt seccomp=unconfined" if (conf.get("enforce_xnack_on", false)) { dockerOpts = dockerOpts + " --env HSA_XNACK=1 " @@ -956,20 +956,20 @@ pipeline { defaultValue: '', description: 'If you want to use a custom docker image, please specify it here (default: leave blank).') string( - name: 'ROCMVERSION', + name: 'ROCMVERSION', defaultValue: '6.4.1', description: 'Specify which ROCM version to use: 6.4.1 (default).') string( - name: 'COMPILER_VERSION', - defaultValue: '', + name: 'COMPILER_VERSION', + defaultValue: '', description: 'Specify which version of compiler to 
use: release, amd-staging, amd-mainline, or leave blank (default).') string( - name: 'COMPILER_COMMIT', - defaultValue: '', + name: 'COMPILER_COMMIT', + defaultValue: '', description: 'Specify which commit of compiler branch to use: leave blank to use the latest commit (default), or use some specific commit of llvm-project branch.') string( - name: 'BUILD_COMPILER', - defaultValue: '/opt/rocm/llvm/bin/clang++', + name: 'BUILD_COMPILER', + defaultValue: '/opt/rocm/llvm/bin/clang++', description: 'Build CK with /opt/rocm/bin/hipcc, /llvm-project/build/bin/clang++, or with /opt/rocm/llvm/bin/clang++ (default).') booleanParam( name: "RUN_FULL_QA", @@ -1448,6 +1448,36 @@ pipeline { cleanWs() } } + stage("Run TILE_ENGINE_GEMM Tests on gfx1201") + { + when { + beforeAgent true + expression { params.RUN_TILE_ENGINE_GEMM_TESTS.toBoolean() } + } + agent{ label rocmnode("gfx1201") } + environment{ + setup_args = "NO_CK_BUILD" + execute_args = """ cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ + -D CMAKE_CXX_COMPILER="${build_compiler()}" \ + -D CMAKE_BUILD_TYPE=Release \ + -D GPU_TARGETS="gfx1201" \ + -D GEMM_DATATYPE="fp16" \ + -D GEMM_LAYOUT="rcr;rrr;crr;ccr" \ + -DGEMM_CONFIG_FILE=gfx120x_config.json \ + -DCMAKE_CXX_FLAGS=" -O3 " .. && \ + ninja -j64 benchmark_gemm_all && \ + python3 ../tile_engine/ops/gemm/gemm_benchmark.py . --problem-sizes "1024,1024,1024" \ + --warmup 5 --repeat 5 --verbose --json results.json && \ + ninja -j64 benchmark_gemm_fp16_rcr && \ + ninja -j64 benchmark_gemm_fp16_rrr && \ + ninja -j64 benchmark_gemm_fp16_crr && \ + ninja -j64 benchmark_gemm_fp16_ccr """ + } + steps{ + buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args) + cleanWs() + } + } } } @@ -1591,7 +1621,7 @@ pipeline { agent{ label rocmnode("gfx942") } steps{ script { - def execute_args = params.NINJA_FTIME_TRACE ? + def execute_args = params.NINJA_FTIME_TRACE ? 
""" cmake -G Ninja -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ @@ -1600,7 +1630,7 @@ pipeline { -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ -D CMAKE_CXX_FLAGS=" -O3 " .. && ninja -j64 """ - + buildHipClangJobAndReboot(setup_cmd: "", build_cmd: "", no_reboot:true, build_type: 'Release', execute_cmd: execute_args, docker_name: "${env.CK_DOCKERHUB_PRIVATE}:ck_ub24.04_rocm7.0") } cleanWs() diff --git a/example/01_gemm/CMakeLists.txt b/example/01_gemm/CMakeLists.txt index 2f9d85d51d..03bde86421 100644 --- a/example/01_gemm/CMakeLists.txt +++ b/example/01_gemm/CMakeLists.txt @@ -105,6 +105,16 @@ foreach(gpu IN LISTS GPU_TARGETS) endif() endforeach() +list(APPEND gpu_list_tf32 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_gemm_xdl_lds_direct_load_fp32_tf32 gemm_xdl_lds_direct_load_fp32_tf32.cpp) + add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32_tf32) + set(target 1) + endif() +endforeach() + add_example_executable(example_gemm_xdl_fp8 gemm_xdl_fp8.cpp) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) diff --git a/example/01_gemm/common.hpp b/example/01_gemm/common.hpp index 434f549443..e482953e46 100644 --- a/example/01_gemm/common.hpp +++ b/example/01_gemm/common.hpp @@ -310,10 +310,14 @@ bool parse_cmd_args(int argc, return true; } -template +template inline __host__ __device__ constexpr double get_rtol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -351,10 +355,14 @@ inline __host__ __device__ constexpr double get_rtol() } } -template +template inline __host__ __device__ constexpr double get_atol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if 
constexpr(std::is_same_v) { return 1e-3; } diff --git a/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp new file mode 100644 index 0000000000..9b92fad779 --- /dev/null +++ b/example/01_gemm/gemm_xdl_lds_direct_load_fp32_tf32.cpp @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "common.hpp" + +#define USING_DIRECT_LOADS 1 +#if USING_DIRECT_LOADS +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp" +#else +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp" +#endif + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +using F32 = float; + +using ADataType = F32; +using BDataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using CDataType = F32; +using ComputeDataType = ck::tf32_t; + +using ALayout = Row; +using BLayout = Col; +using CLayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +#if USING_DIRECT_LOADS +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle_LdsDirectLoad +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| +// ######| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockLds| +// ######| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| LoopScheduler | pipeline ver | gemm type | +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| +// ######| XDL| XDL| Per| Per| ThreadCluster| SrcAccessOrder| 
SrcVectorDim| Scalar| AddExtraM| ThreadCluster| SrcAccessOrder| SrcVectorDim| Scalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| | | PerVector| | Lengths_K0_N_K1| | | PerVector| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, + 8, 8, 32, 32, 2, 2, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, S<4, 8, 8>, S<1, 0, 2>, 2, 1, 1, + 1, 1, S<1, 8, 1, 8>, 4, ck::LoopScheduler::Default, ck::PipelineVersion::v4, ComputeDataType>; +// clang-format on +#else +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle +// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| 
ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 4>; +// clang-format on +#endif +using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + +using ReferenceGemmInstanceGPU = ck::tensor_operation::device::ReferenceGemm; + +#include "run_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/01_gemm/run_gemm_example.inc b/example/01_gemm/run_gemm_example.inc index 3e018aad1e..08e2b8c15f 100644 --- a/example/01_gemm/run_gemm_example.inc +++ b/example/01_gemm/run_gemm_example.inc @@ -4,6 +4,11 @@ #pragma once #include "ck/library/utility/validation_common.hpp" +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + template bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) { @@ -218,8 +223,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_host_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); #endif } @@ -249,8 +254,8 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config) pass &= ck::utils::check_err(c_m_n_device_result, c_m_n_device_ref_result, "Error: Incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return pass == true; diff --git a/example/09_convnd_fwd/CMakeLists.txt b/example/09_convnd_fwd/CMakeLists.txt index 
91c072aef7..4f174bfcbb 100644 --- a/example/09_convnd_fwd/CMakeLists.txt +++ b/example/09_convnd_fwd/CMakeLists.txt @@ -19,4 +19,13 @@ foreach(gpu IN LISTS GPU_TARGETS) add_example_executable(example_convnd_fwd_xdl_fp64 convnd_fwd_xdl_fp64.cpp) set(target 1) endif() -endforeach() \ No newline at end of file +endforeach() + +list(APPEND gpu_list_tf32 gfx942) +set(target 0) +foreach(gpu IN LISTS GPU_TARGETS) + if(gpu IN_LIST gpu_list_tf32 AND target EQUAL 0) + add_example_executable(example_convnd_fwd_xdl_fp32_tf32 convnd_fwd_xdl_fp32_tf32.cpp) + set(target 1) + endif() +endforeach() diff --git a/example/09_convnd_fwd/convnd_fwd_common.hpp b/example/09_convnd_fwd/convnd_fwd_common.hpp index b0fd6a382a..d82b56ec00 100644 --- a/example/09_convnd_fwd/convnd_fwd_common.hpp +++ b/example/09_convnd_fwd/convnd_fwd_common.hpp @@ -27,10 +27,14 @@ void print_helper_msg() << ck::utils::conv::get_conv_param_parser_helper_msg() << std::endl; } -template +template inline __host__ __device__ constexpr double get_rtol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 5e-3; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -68,10 +72,14 @@ inline __host__ __device__ constexpr double get_rtol() } } -template +template inline __host__ __device__ constexpr double get_atol() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-2; + } + else if constexpr(std::is_same_v) { return 1e-3; } @@ -116,7 +124,8 @@ template + typename DeviceConvNDFwdInstance, + typename ComputeDataType = OutDataType> bool run_grouped_conv_fwd(bool do_verification, int init_method, bool time_kernel, @@ -228,7 +237,11 @@ bool run_grouped_conv_fwd(bool do_verification, OutDataType, InElementOp, WeiElementOp, - OutElementOp>(); + OutElementOp, + 0, + 0, + 0, + ComputeDataType>(); auto ref_invoker = ref_conv.MakeInvoker(); auto ref_argument = ref_conv.MakeArgument(in, @@ -249,8 +262,8 @@ bool run_grouped_conv_fwd(bool 
do_verification, return ck::utils::check_err(out_device, out_host, "Error: incorrect results!", - get_rtol(), - get_atol()); + get_rtol(), + get_atol()); } return true; diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp new file mode 100644 index 0000000000..348da7e1ef --- /dev/null +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp32_tf32.cpp @@ -0,0 +1,89 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "convnd_fwd_common.hpp" + +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp" + +#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" + +#define EXAMPLE_WITH_COMPUTE_DATATYPE + +using InDataType = float; +using WeiDataType = float; +using AccDataType = float; +using CShuffleDataType = float; +using OutDataType = float; +using ComputeDataType = ck::tf32_t; + +template +using S = ck::Sequence; + +using InElementOp = ck::tensor_operation::element_wise::PassThrough; +using WeiElementOp = ck::tensor_operation::element_wise::PassThrough; +using OutElementOp = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvSpec = + ck::tensor_operation::device::ConvolutionForwardSpecialization::Default; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +template +using DeviceGroupedConvNDFwdInstance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< + NDimSpatial, + InLayout, // ALayout + WeiLayout, // BLayout + ck::Tuple<>, // DsLayout + OutLayout, // ELayout + InDataType, // ADataType + WeiDataType, // BDataType + AccDataType, // AccDataType + CShuffleDataType, // CShuffleDataType + ck::Tuple<>, // DsDataType + OutDataType, // EDataType + InElementOp, // AElementwiseOperation + WeiElementOp, // BElementwiseOperation + OutElementOp, // CDEElementwiseOperation + ConvSpec, // 
ConvForwardSpecialization + GemmSpec, // GemmSpecialization + 1, // NumGemmKPrefetchStage + 256, // BlockSize + 128, // MPerBlock + 192, // NPerBlock + 16, // KPerBlock + 4, // AK1 + 4, // BK1 + 32, // MPerXdl + 32, // NPerXdl + 2, // MXdlPerWave + 3, // NXdlPerWave + S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 + S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // ABlockTransferSrcAccessOrder + 2, // ABlockTransferSrcVectorDim + 4, // ABlockTransferSrcScalarPerVector + 4, // ABlockTransferDstScalarPerVector_AK1 + 1, // ABlockLdsExtraM + S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 + S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder + S<1, 0, 2>, // BBlockTransferSrcAccessOrder + 2, // BBlockTransferSrcVectorDim + 4, // BBlockTransferSrcScalarPerVector + 4, // BBlockTransferDstScalarPerVector_BK1 + 1, // BBlockLdsExtraN + 1, // CShuffleMXdlPerWavePerShuffle + 1, // CShuffleNXdlPerWavePerShuffle + S<1, 16, 1, 16>, // CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock + 4, // CDEBlockTransferScalarPerVector_NPerBlock + ComputeDataType, // AComputeDataType + ComputeDataType, // BComputeDataType + ck::LoopScheduler::Default, // LoopScheduler + 1 // NumGroupsToMerge + >; + +#include "run_convnd_fwd_example.inc" + +int main(int argc, char* argv[]) { return run_convnd_fwd_example(argc, argv) ? 0 : 1; } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp index fde0f51bc7..c635d01d8f 100644 --- a/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp +++ b/example/09_convnd_fwd/convnd_fwd_xdl_fp8.cpp @@ -7,6 +7,8 @@ #include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp" +#define EXAMPLE_WITH_COMPUTE_DATATYPE + using InDataType = ck::f8_t; using WeiDataType = ck::f8_t; using AccDataType = float; @@ -87,3 +89,5 @@ int main(int argc, char* argv[]) } return run_convnd_fwd_example(argc, argv) ? 
0 : 1; } + +#undef EXAMPLE_WITH_COMPUTE_DATATYPE diff --git a/example/09_convnd_fwd/run_convnd_fwd_example.inc b/example/09_convnd_fwd/run_convnd_fwd_example.inc index 49852ff667..016a189d4b 100644 --- a/example/09_convnd_fwd/run_convnd_fwd_example.inc +++ b/example/09_convnd_fwd/run_convnd_fwd_example.inc @@ -3,6 +3,11 @@ #pragma once +// use macro to minimize code change +#ifndef EXAMPLE_WITH_COMPUTE_DATATYPE +using ComputeDataType = AccDataType; +#endif + bool run_convnd_fwd_example(int argc, char* argv[]) { print_helper_msg(); @@ -65,17 +70,17 @@ bool run_convnd_fwd_example(int argc, char* argv[]) InElementOp, WeiElementOp, OutElementOp, - DeviceGroupedConvNDFwdInstance>( - do_verification, - init_method, - time_kernel, - conv_param, - in_g_n_c_wis_desc, - wei_g_k_c_xs_desc, - out_g_n_k_wos_desc, - in_element_op, - wei_element_op, - out_element_op); + DeviceGroupedConvNDFwdInstance, + ComputeDataType>(do_verification, + init_method, + time_kernel, + conv_param, + in_g_n_c_wis_desc, + wei_g_k_c_xs_desc, + out_g_n_k_wos_desc, + in_element_op, + wei_element_op, + out_element_op); }; namespace ctc = ck::tensor_layout::convolution; diff --git a/example/14_gemm_quantization/CMakeLists.txt b/example/14_gemm_quantization/CMakeLists.txt index 8703fa3ed7..b058e7b0fa 100644 --- a/example/14_gemm_quantization/CMakeLists.txt +++ b/example/14_gemm_quantization/CMakeLists.txt @@ -1,3 +1,4 @@ add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp) +add_example_executable(example_gemm_wmma_quantization_int8 gemm_wmma_quantization_int8.cpp) add_example_executable(example_gemm_xdl_bias_relu_quantization_int8 gemm_xdl_bias_relu_quantization_int8.cpp) add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp) diff --git a/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp new file mode 100644 index 0000000000..a3023997a1 --- /dev/null +++ 
b/example/14_gemm_quantization/gemm_wmma_quantization_int8.cpp @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using ActivationOp = PassThrough; +using CDEElementOp = ck::tensor_operation::element_wise::Activation_Mul_Clamp; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using DsDataType = ck::Tuple<>; +using EDataType = I8; + +using ALayout = Col; +using BLayout = Row; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; + +// clang-format off +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmMultipleD_Wmma_CShuffleV3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + EDataType, + AccDataType, + CShuffleDataType, + ActivationOp, + ActivationOp, + CDEElementOp, + GemmDefault, + 256, + 128, + 128, + 64, + 8, + 8, + 16, + 16, + 4, + 2, + S<4, 64, 1>, + S<0, 2, 1>, 
+ S<0, 2, 1>, + 1, + 1, + 8, + true, + S<4, 64, 1>, + S<0, 2, 1>, + S<0, 2, 1>, + 1, + 1, + 8, + true, + 1, + 1, + S<1, 32, 1, 8>, + S<1>, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + I8, + I8>; +// clang-format on + +using ReferenceGemmInstance = ck::tensor_operation::host:: + ReferenceGemm; + +int main(int /* argc */, char* /* argv */[]) +{ + bool do_verification = true; + bool time_kernel = false; + + // GEMM shape + ck::index_t M = 1024; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = N; + ck::index_t StrideE = N; + + float requant_scale = 0.03; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + auto a_element_op = PassThrough{}; + auto b_element_op = PassThrough{}; + auto cde_element_op 
= CDEElementOp{requant_scale, ActivationOp{}}; + + // device GEMM + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + auto argument = gemm.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), + static_cast(b_device_buf.GetDeviceBuffer()), + std::array{}, + static_cast(e_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << gemm.GetTypeString() << std::endl; + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + if(do_verification) + { + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument( + a_m_k, b_k_n, e_m_n_host_result, a_element_op, b_element_op, cde_element_op); + + ref_invoker.Run(ref_argument); + + return ck::utils::check_err(e_m_n_device_result, e_m_n_host_result) ? 
0 : 1; + } + + return 0; +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp index 5bdc993192..2fcc0e3cb1 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp @@ -323,6 +323,31 @@ int main(int argc, char* argv[]) problem_size.Ms = {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (>0)\n"); + printf("arg5: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ns.push_back(768); @@ -333,21 +358,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (>0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp 
b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 6806bd1886..fb611fd444 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -296,6 +296,32 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + printf("arg5: group count (default=16)"); + + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(128 + rand() % 128); @@ -307,21 +333,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (> 0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp index 8418c10f5e..47eb6637bd 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp +++ 
b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16_fp8.cpp @@ -297,6 +297,31 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + printf("arg5: group count (default=16)"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); @@ -308,21 +333,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.k_batch = std::stoi(argv[4]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - printf("arg4: k_batch (> 0)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp index 9f8f6cb1e4..16d018936b 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_splitk_fp16.cpp @@ -66,6 +66,28 @@ int main(int argc, char* argv[]) problem_size.group_count = 16; + if(argc == 4) + { + config.do_verification = 
std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + } + else if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + problem_size.group_count = std::stoi(argv[4]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: group count (default=16)\n"); + exit(0); + } + for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); @@ -77,19 +99,5 @@ int main(int argc, char* argv[]) problem_size.stride_Cs.push_back(problem_size.Ns[i]); } - if(argc == 4) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - } - else - { - printf("arg1: verification (0=no, 1=yes)\n"); - printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); - printf("arg3: time kernel (0=n0, 1=yes)\n"); - exit(0); - } - return !run_grouped_gemm(problem_size, config); } diff --git a/example/20_grouped_conv_bwd_weight/common.hpp b/example/20_grouped_conv_bwd_weight/common.hpp index e0034bf7eb..9159e51eaf 100644 --- a/example/20_grouped_conv_bwd_weight/common.hpp +++ b/example/20_grouped_conv_bwd_weight/common.hpp @@ -123,7 +123,9 @@ inline bool parse_cmd_args(int argc, const ck::index_t num_dim_spatial = std::stoi(argv[4]); conv_param = ck::utils::conv::parse_conv_param( - num_dim_spatial, threshold_to_catch_partial_args, argv); + num_dim_spatial, + threshold_to_catch_partial_args + 1, // +1 because we already parsed num_dim_spatial + argv); } else { diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index 68db468a7c..3d79f2f6d3 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@
-213,8 +213,20 @@ list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS -Wno-undefined-func-template --save-temps ) -target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS) +check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32) +if(HAS_DISABLE_PACKED_FP32) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS + -mllvm --amdgpu-disable-packed-fp32=1 + ) + list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS + -DCK_TILE_DISABLE_PACKED_FP32=1 + ) +endif() + +target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS}) +target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated # by cmake will print too many files, execvp: /bin/sh: Argument list too long # however, this property may affect global diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index cb6cd44f64..7f55d7412f 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -131,4 +131,4 @@ TBD ## FP8 experimental support As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+. -Currently we only support `-vlayout=c`( `hdim*seqlen` for V matrix) and `-squant=1`(static quantization) with `hdim=128` for fp8 now. Full feature support will come later. +Currently we only support `-vlayout=r` (`seqlen*hdim` for the V matrix) for fp8 and fp8bf16. Full feature support will come later.
diff --git a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py index 42a9d5148a..802c9e51d7 100644 --- a/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py +++ b/example/ck_tile/01_fmha/codegen/cpp_symbol_map.py @@ -7,7 +7,8 @@ FWD_DTYPE_MAP = { "bf16" : "FmhaFwdBf16", "fp8" : "FmhaFwdFp8", "fp8fp16": "FmhaFwdFp8Fp16", - "fp8bf16": "FmhaFwdFp8Bf16" + "fp8bf16": "FmhaFwdFp8Bf16", + "fp8fp32": "FmhaFwdFp8Fp32" } BWD_DTYPE_MAP = { diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index d9452206e7..cfb96b7d53 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -163,7 +163,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& [[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{ return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0); }}; - + const bool has_load_tr = ck_tile::is_load_tr_supported(); {F_dispatch} @@ -248,11 +248,11 @@ class FmhaFwdApiTrait: if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! 
(ugly) else : return f'a.seqlen_q % {self.bm0} == 0' else: assert False - + @property def seqtune(self) -> str: if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true - else: + else: return f'a.seqlen_q <= {self.bm0}' @property @@ -351,7 +351,7 @@ class FmhaFwdPipeline: if self.F_squant == 't' : n += '_squant' else: n += '_nsquant' - + if self.F_trload == 't' : n += '_trload' else: n += '_ntrload' @@ -378,7 +378,7 @@ class FmhaFwdApiPool: "t": "has_load_tr", "f": "true" } - + per_tr_load =str() for tr_load in ["t", "f"]: per_dtypes=str() @@ -550,12 +550,16 @@ class KernelComponentFactory: (192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)], } - elif dtype == 'fp8' or dtype == 'bf8': + elif dtype == 'fp8' or dtype == 'fp8bf16': return { (64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], (256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], } + elif dtype == 'fp8fp32': + return { + (128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)], + } else: return None @@ -567,9 +571,9 @@ class KernelComponentFactory: # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! # TODO: how to design this more generic? 
- squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: + squant = 'f' for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): if hdim == 256 and hdim_v == 256: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f')) @@ -589,11 +593,12 @@ class KernelComponentFactory: pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't')) if receipt == 1 and bias != "bias": pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim - elif dtype in ['fp8', 'bf8']: + elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']: # no need lse/dropout kernels - for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): - pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) - elif dtype in ['fp8fp16', 'fp8bf16']: + for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f')) + elif dtype in ['fp8fp16', 'bf8']: # TODO None else: @@ -674,25 +679,34 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl continue # Aiter(mha_fwd) integration elif receipt == 100: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'batch' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # 
Aiter(mha_varlen_fwd) integration elif receipt == 200: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= mode == 'group' cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 if not cond: continue # aiter::mha_fwd C++ api integration elif receipt == 600: - cond = dtype in ['fp16', 'bf16'] + cond = dtype in ['fp16', 'bf16', 'fp8bf16'] cond &= pipeline.F_vlayout == 'row' - cond &= pipeline.F_squant == 'f' + if dtype == 'fp8bf16': + cond &= hdim == 128 + if not cond: + continue + elif receipt == 888: + cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32'] + cond &= pipeline.F_vlayout == 'row' + cond &= hdim == 128 if not cond: continue diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 3b48b3d005..cee1505486 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -645,7 +645,6 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: return { '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py index 7b93e9654c..df6b422981 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_pagedkv_prefill.py @@ -465,14 +465,14 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl squant = 't' if dtype == 'fp8' else 'f' pipelines = [] if dtype in ['fp16', 'bf16']: - for logits, mask, bias, pagedkv, skip in itertools.product(["t", 
"f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]): - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) - pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'col', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) + for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]): pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip)) elif dtype in ['fp8', 'bf8']: - # TODO - None + # no need lse/dropout kernels + for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()): + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) + pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f')) elif dtype in ['fp8fp16', 'fp8bf16']: # TODO None diff --git a/example/ck_tile/01_fmha/example_fmha_fwd.cpp b/example/ck_tile/01_fmha/example_fmha_fwd.cpp index c3bbb7a558..91cb9f55be 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd.cpp @@ -44,21 +44,15 @@ auto create_args(int argc, char* argv[]) .insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim).\n" - "note when squant=1, this value will be modified by range_q/k") + "note when squant=1, this value will be modified") .insert("logits_soft_cap", "0", "attention logits soft capping value.") - .insert("range_q", "16", "per-tensor quantization range of q. used if squant=1.") - .insert("range_k", "16", "per-tensor quantization range of k. used if squant=1.") - .insert("range_v", "16", "per-tensor quantization range of v. 
used if squant=1.") - .insert("range_p", "1", "per-tensor quantization range of p [e^(s-m)]. used if squant=1.") - .insert("range_o", "16", "per-tensor quantization range of o (p*v). used if squant=1.") .insert("squant", "auto", "if using static quantization fusion or not. auto: fp8 will default use squant, " "other will not\n" "0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to " "P and O.\n" - "calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, " - "range_p, range_o") + "calculate scale_s, scale_p, scale_o auto") .insert("iperm", "1", "permute input\n" @@ -89,7 +83,7 @@ auto create_args(int argc, char* argv[]) "uf", "init method:\n ui or 0 - uniform random int\n ni - normalized random int" "\n uf or 1 - uniform random float\n nf - normalized random float" - "\n tf or 2 - trig float\n uf:q or ufq or 3 - fp8 quantization") + "\n tf or 2 - trig float\n") .insert("seed", "11939", "random seed used for initializing input tensors. 0 for " @@ -148,11 +142,6 @@ auto run(const ck_tile::ArgParser& arg_parser) uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); bool drop_prefs = arg_parser.get_bool("drop_prefs"); std::string mask_str = arg_parser.get_str("mask"); - float range_q = arg_parser.get_float("range_q"); - float range_k = arg_parser.get_float("range_k"); - float range_v = arg_parser.get_float("range_v"); - float range_p = arg_parser.get_float("range_p"); - float range_o = arg_parser.get_float("range_o"); bool is_rotary_interleaved = arg_parser.get_bool("rotary_interleaved"); ck_tile::index_t num_splits = arg_parser.get_int("num_splits"); std::string init_method = arg_parser.get_str("init"); @@ -201,11 +190,6 @@ auto run(const ck_tile::ArgParser& arg_parser) drop_offset, drop_prefs, mask_str, - range_q, - range_k, - range_v, - range_p, - range_o, squant, is_rotary_interleaved, num_splits, @@ -237,6 +221,14 @@ int main(int argc, char* argv[]) { return run(arg_parser) == fwd_result::success ? 
0 : -2; } + else if(data_type == "fp8bf16") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } + else if(data_type == "fp8fp32") + { + return run(arg_parser) == fwd_result::success ? 0 : -2; + } std::cerr << "Unsupported precision: " << data_type << std::endl; return -1; } diff --git a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp index d2428e5152..569c98a458 100644 --- a/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp +++ b/example/ck_tile/01_fmha/example_fmha_fwd_v3.cpp @@ -45,18 +45,7 @@ auto parse_cmd_args(int argc, char* argv[]) -> std::pair(hdim)); - mask = mask_info::decode(args.get_str("mask"), seqlen_q, seqlen_k); + + const auto is_causal = args.get_bool("causal"); + if(is_causal) + { + mask = mask_info::decode("b:-1,0", seqlen_q, seqlen_k); + } + else + { + mask = mask_info::decode("0", seqlen_q, seqlen_k); + } input_layout = args.get_int("iperm") == 1 ? TensorLayout::bhsd : TensorLayout::bshd; output_layout = args.get_int("operm") == 1 ? 
TensorLayout::bhsd : TensorLayout::bshd; diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index df1e9e5699..c41e48e6aa 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -41,6 +41,10 @@ struct FmhaFwdFp8Bf16 { }; +struct FmhaFwdFp8Fp32 +{ +}; + template struct FmhaFwdTypeConfig; @@ -108,6 +112,38 @@ struct FmhaFwdTypeConfig using ODataType = ck_tile::bf8_t; }; +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = ck_tile::bf16_t; +}; + +template <> +struct FmhaFwdTypeConfig +{ + using QDataType = ck_tile::fp8_t; + using KDataType = ck_tile::fp8_t; + using VDataType = ck_tile::fp8_t; + using BiasDataType = float; + using RandValOutputDataType = uint8_t; + using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j)) + using SaccDataType = float; // data type for first gemm accumulation + using SMPLComputeDataType = float; // data type for reduction, softmax + using PDataType = ck_tile::fp8_t; // data type for A matrix of second gemm + using OaccDataType = float; // data type for second gemm accumulation + using ODataType = float; +}; + struct FmhaMasks { using NoMask = ck_tile::GenericAttentionMask; diff --git a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp index 397245ab32..43f484fe14 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_runner.hpp +++ 
b/example/ck_tile/01_fmha/fmha_fwd_runner.hpp @@ -50,20 +50,30 @@ auto get_elimit(std::string /*init_method*/) } template <> -auto get_elimit(std::string init_method) +auto get_elimit(std::string /*init_method*/) { - if(init_method == "ui" || init_method == "ni") - { - unsigned max_rounding_point_distance = 0; - double atol = 2e-3; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } - else - { - unsigned max_rounding_point_distance = 1; - double atol = 0.0625; - return ck_tile::make_tuple(max_rounding_point_distance, atol); - } + using TypeConfig = FmhaFwdTypeConfig; + using ODataType = typename TypeConfig::ODataType; + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + double rtol = 0; + double atol = 16 * (o_dtype_max > 240 ? 2 : 1); + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1.8e-1; + return ck_tile::make_tuple(rtol, atol); } int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) @@ -157,11 +167,6 @@ fwd_result fmha_fwd_run(mode_enum mode, uint64_t drop_offset, bool drop_prefs, std::string mask_str, - float range_q, - float range_k, - float range_v, - float range_p, - float range_o, bool squant, bool is_rotary_interleaved, ck_tile::index_t num_splits, @@ -180,6 +185,10 @@ fwd_result fmha_fwd_run(mode_enum mode, return "fp8"; else if constexpr(std::is_same_v) return "bf8"; + else if constexpr(std::is_same_v) + return "fp8bf16"; + else if constexpr(std::is_same_v) + return "fp8fp32"; else static_assert(false); }(); @@ -367,22 +376,6 @@ fwd_result fmha_fwd_run(mode_enum mode, using OaccDataType = typename TypeConfig::OaccDataType; using ODataType = typename TypeConfig::ODataType; - float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - 
float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - float p_dtype_max = v_dtype_max; // assume p and v is the same type - float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); - - float scale_p = 1.f; - float scale_o = 1.f; - - if(squant) - { - scale_s = scale_s * (range_q / q_dtype_max) * (range_k / k_dtype_max); - scale_p = p_dtype_max / range_p; - scale_o = (o_dtype_max / range_o) * (range_p / p_dtype_max) * (range_v / v_dtype_max); - } - // accumulation numbers for performance evaluation std::size_t flop = 0, num_byte = 0; auto max_seqlen_q = @@ -528,7 +521,7 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::HostTensor cache_batch_idx_host(use_cache_batch_idx ? std::array{batch} : std::array{1}); - + float max_o = 5.0; if(init_method == "ui" || init_method == "0") { ck_tile::FillUniformDistributionIntegerValue{-3.f, 3.f, next_seed()}(q_host); @@ -576,32 +569,6 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::FillTrigValue{}(vnew_host); ck_tile::FillTrigValue{}(bias_host); } - else if(init_method == "ufq" || init_method == "uf:q" || init_method == "3") - { - // suitable for fp8 quantization - if(!squant) - { - std::cerr << "init method " << init_method << " can not be used without quantization" - << std::endl; - return fwd_result::invalid_args; - } - ck_tile::FillUniformDistribution{0.f, q_dtype_max, next_seed()}(q_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(k_host); - ck_tile::FillUniformDistribution{0.f, k_dtype_max, next_seed()}(knew_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(v_host); - ck_tile::FillUniformDistribution{0.f, v_dtype_max, next_seed()}(vnew_host); - - // bias_fp8 = qscale_bias * bias_fp32 - float qscale_bias = (q_dtype_max / range_q) * (k_dtype_max / range_k); - // Assume bias is in [0.f, 1.f] in original fp32 - ck_tile::FillUniformDistribution{0.f, qscale_bias, 
next_seed()}(bias_host); - } - else - { - std::cerr << "Unknown value for init argument: " << init_method << std::endl; - return fwd_result::invalid_args; - } - if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -625,8 +592,8 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem knew_buf(knew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem vnew_buf(vnew_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes()); @@ -650,10 +617,79 @@ fwd_result fmha_fwd_run(mode_enum mode, ck_tile::DeviceMem block_table_buf(block_table_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem cache_batch_idx_buf(cache_batch_idx_host.get_element_space_size_in_bytes()); + float scale_p = 1.f; + float scale_o = 1.f; + if(squant) + { + float q_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float k_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float v_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + float p_dtype_max = v_dtype_max; // assume p and v is the same type + // Q tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = q_dtype_max / max_value; + + q_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // K tensor + { + float max_value = 
ck_tile::type_convert(ck_tile::numeric::min()); + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + float scale = k_dtype_max / max_value; + k_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + scale_s = scale_s / scale; + } + + // V tensor + { + float max_value = ck_tile::type_convert(ck_tile::numeric::min()); + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + if(val > max_value) + max_value = val; + }); + + float scale = v_dtype_max / max_value; // scale V against its own dtype max, not K's + v_host.ForEach([&](auto& self, auto idx) { + float val = ck_tile::type_convert(self(idx)); + self(idx) = ck_tile::type_convert(val * scale); + }); + + scale_o = (1.0 / p_dtype_max) / scale; + } + + scale_p = p_dtype_max; + + if constexpr(std::is_same_v) + { + float o_dtype_max = ck_tile::type_convert(ck_tile::numeric::max()); + scale_o = scale_o * o_dtype_max / max_o; + } + } + q_buf.ToDevice(q_host.data()); k_buf.ToDevice(k_host.data()); - knew_buf.ToDevice(knew_host.data()); v_buf.ToDevice(v_host.data()); + knew_buf.ToDevice(knew_host.data()); vnew_buf.ToDevice(vnew_host.data()); bias_buf.ToDevice(bias_host.data()); seqstart_q.ToDevice(seqstart_q_host.data()); @@ -1103,7 +1139,9 @@ fwd_result fmha_fwd_run(mode_enum mode, lse_buf.FromDevice(lse_host.data()); randval_buf.FromDevice(randval_host.data()); - constexpr bool supports_squant = std::is_same_v; + constexpr bool supports_squant = std::is_same_v || + std::is_same_v || + std::is_same_v; auto p_compute_element_func = [&]() { if constexpr(supports_squant) @@ -1113,9 +1151,11 @@ fwd_result fmha_fwd_run(mode_enum mode, }(); auto oacc_element_func = [&]() { - if constexpr(supports_squant) + if constexpr(std::is_same_v && supports_squant) return ck_tile::composes(ck_tile::saturates{}, ck_tile::scales{scale_o}); + else if
constexpr(supports_squant) + return ck_tile::scales{scale_o}; else return ck_tile::identity{}; }(); diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp index 5361d27f0f..10cb5149a4 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3.hpp @@ -34,7 +34,8 @@ struct fmha_fwd_v3_args index_t window_size_left; index_t window_size_right; - index_t mask_type; + index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and + // window_size_right == 0). const void* q_ptr; index_t stride_q; diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index d6e4ac4c60..e0fbad39a5 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -18,6 +18,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "fmha_fwd_v3.hpp" +#include "mask.hpp" #define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ template <> \ @@ -79,7 +80,7 @@ struct fmha_fwd_v3_kernel_traits -1 // kBlockPerCu >; - using fmha_mask = SimplifiedGenericAttentionMask; + using fmha_mask = GenericAttentionMask; using fmha_pipeline_problem = BlockFmhaFwdV3PipelineProblem::qkvp_dtype, @@ -112,6 +113,22 @@ struct fmha_fwd_v3_kernel_traits template float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_config& config) { + /// NOTICE: This was borrowed from Aiter. Make sure the selected remap_opt setting truly + /// maximizes the kernel's performance. 
+ int remap_opt = 2; + if(args.mask_type != static_cast(mask_enum::no_mask) && + ((args.nhead_q % 8 != 0) || (16384 < args.seqlen_q))) + { + if(65536 <= args.seqlen_q) + { + remap_opt = 0; + } + else + { + remap_opt = 1; + } + } + auto kargs = Kernel::MakeKargs(args.q_ptr, args.k_ptr, args.v_ptr, @@ -140,7 +157,8 @@ float fmha_fwd_v3_kernel_launch(const fmha_fwd_v3_args& args, const stream_confi args.batch_stride_o, args.window_size_left, args.window_size_right, - args.mask_type); + args.mask_type, + remap_opt); dim3 grids = Kernel::GridSize(args.batch, args.nhead_q, args.seqlen_q, args.hdim_v); constexpr dim3 blocks = Kernel::BlockSize(); diff --git a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh index 9c500edf9d..b847e85398 100755 --- a/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh +++ b/example/ck_tile/01_fmha/script/benchmark_fwd_v3.sh @@ -8,22 +8,16 @@ for prec in "fp16" "bf16" ; do for hdim in 128 ; do for perm in 0 ; do -if [ $causal -eq 0 ]; then - mask=0 -else - mask=b:-1,0 -fi - -$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal 
-iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID -$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -mask=$mask -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID +$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID done done diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt new file mode 100644 index 0000000000..ea601ec002 --- /dev/null +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx90a.txt @@ -0,0 +1,2 @@ +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt new file mode 100644 index 0000000000..ea601ec002 --- /dev/null +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx942.txt @@ -0,0 +1,2 @@ +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 
-s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt new file mode 100644 index 0000000000..1497d491bb --- /dev/null +++ b/example/ck_tile/01_fmha/script/fmha_bwd_known_fails_gfx950.txt @@ -0,0 +1,31 @@ +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 
-operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=32 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=32 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 
-iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=64 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=64 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=0 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=2 -h=2 -d=128 -s=516 -s_k=253 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 +tile_example_fmha_bwd -prec=bf16 -b=1 -h=4 -h_k=1 -d=128 -s=500 -s_k=251 -bias=a -dbias=0 -p_drop=0.0 -iperm=1 -operm=1 -mask=1 -deterministic=0 -v=1 -mode=1 -kname=1 -v=1 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx90a.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx90a.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx942.txt b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx942.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt 
b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt new file mode 100644 index 0000000000..90c5e2b7fb --- /dev/null +++ b/example/ck_tile/01_fmha/script/fmha_fwd_known_fails_gfx950.txt @@ -0,0 +1,4 @@ +tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 +tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=0 -operm=0 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 +tile_example_fmha_fwd -prec=fp16 -mode=0 -b=2 -h=1 -d=128 -d_v=24 -s=3 -s_k=99 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 +tile_example_fmha_fwd -prec=fp16 -mode=0 -b=1 -h=2 -h_k=1 -d=128 -s=1 -s_k=10 -s_kpad=32 -bias=n -p_drop=0.0 -lse=0 -iperm=1 -operm=1 -mask=2 -num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=1 -v=1 -warmup=0 -repeat=1 diff --git a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh index d123f842a2..3b59505ff0 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_bwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_bwd.sh @@ -2,13 +2,35 @@ # TODO: run this script from CK root or build directory set -euo pipefail -EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)" +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_bwd +EXE="$(find . 
-name $EXE_NAME -type f | head -n 1)" KNAME=1 +GPU_arch=${GPU_arch:-} # default to empty: with set -u an unset GPU_arch would abort here +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi export CK_WARMUP=0 export CK_REPEAT=1 +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"} + COMMON_ARGS='-v=1' + +run_exe() { + set +ex + $EXE "$@" + local ret=$? + if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + set -x for prec in "fp16" "bf16" ; do for perm in 0 1 ; do @@ -19,12 +41,12 @@ for dbias in 0 ; do for p_drop in 0.0 0.2 ; do for deterministic in 0 ; do -$EXE -prec=$prec -b=1 -h=4 -h_k=2 -d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS -$EXE -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=1 -h=4 -h_k=2
-d=$hdim -s=259 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=2 -h=2 -d=$hdim -s=516 -s_k=253 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=1 -h=4 -h_k=1 -d=$hdim -s=500 -s_k=251 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=1 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=1 -h=2 -d=$hdim -s=900 -s_k=258 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=2 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=2 -h=1 -d=$hdim -s=987 -s_k=219 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=t:128,30 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS +run_exe -prec=$prec -b=2 -h=3 -h_k=1 -d=$hdim -s=244 -s_k=499 -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -mask=b:4,35 -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS done done @@ -35,3 +57,24 @@ done done done set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit $(($new_fails_count != 0)) diff --git 
a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 3913a0d5c2..afd0c728c6 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -2,12 +2,23 @@ # TODO: run this script from CK root or build directory set -euo pipefail -EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)" +SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd) +EXE_NAME=tile_example_fmha_fwd +EXE="$(find . -name $EXE_NAME -type f | head -n 1)" KNAME=1 +GPU_arch=${GPU_arch:-} # default to empty: with set -u an unset GPU_arch would abort here +if [ -z "$GPU_arch" ] ; then + GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}') +fi export CK_WARMUP=0 export CK_REPEAT=1 +CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"} +rm -f $CURR_FAILS_FILE +touch $CURR_FAILS_FILE +KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"} + COMMON_ARGS='-v=1 -warmup=0 -repeat=1' # mode=0 # export HIP_VISIBLE_DEVICES=4 @@ -30,6 +41,16 @@ while getopts ":sa" opt; do esac done +run_exe() { + set +ex + $EXE "$@" + local ret=$?
+ if [ $ret -ne 0 ] ; then + echo "$EXE_NAME $*" >> $CURR_FAILS_FILE + fi + set -ex +} + run_fp16_bf16_tests() { local NUM_SPLITS="1" local PAGE_BLOCK_SIZE="0" @@ -52,16 +73,16 @@ run_fp16_bf16_tests() { for page_block_size in $PAGE_BLOCK_SIZE ; do for cache_batch_idx in $CACHE_BATCH_IDX ; do - # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits 
-page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS - $EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + # run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse 
-iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS + run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done ; done ; done @@ -73,7 +94,29 @@ run_fp8_tests() { for b in 1 2 ; do for hdim in 64 128 256 ; do - $EXE -prec=fp8 -init=3 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=c -squant=1 -kname=$KNAME $COMMON_ARGS + $EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8bf16_tests() { + for perm in 0 1 ; do + 
for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS + + done ; done ; done ; done +} + +run_fp8fp32_tests() { + for perm in 0 1 ; do + for bias in "n" "e" "a" ; do + for b in 1 2 ; do + for hdim in 64 128 256 ; do + + $EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS done ; done ; done ; done } @@ -88,7 +131,7 @@ run_fp16_appendkv_tests() { for page_block_size in 0 128 ; do for cache_batch_idx in 0 1 ; do - $EXE -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS + run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS done ; done ; done ; done ; done done ; done ; done @@ -98,9 +141,32 @@ set -x run_fp16_bf16_tests run_fp8_tests +run_fp8bf16_tests +run_fp8fp32_tests if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests fi set +x + +new_fails_count=0 +known_fails_count=0 +if [ -f $KNOWN_FAILS_FILE ] ; then + echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):" + while IFS= read -r line; do + if grep -Fxq "$line" $KNOWN_FAILS_FILE; then + echo "Known fail: $line" + known_fails_count=$(($known_fails_count + 1)) + else + echo "New fail: $line" + new_fails_count=$(($new_fails_count + 1)) + fi + done < $CURR_FAILS_FILE +else + new_fails_count=$(wc -l < $CURR_FAILS_FILE) + echo "No known fails file, all fails ($new_fails_count) are new:" + cat $CURR_FAILS_FILE +fi +echo "New fails count: $new_fails_count; Known fails count: $known_fails_count" +exit 
$(($new_fails_count != 0)) diff --git a/example/ck_tile/17_grouped_gemm/CMakeLists.txt b/example/ck_tile/17_grouped_gemm/CMakeLists.txt index 8e8026d88d..4f3b173c55 100644 --- a/example/ck_tile/17_grouped_gemm/CMakeLists.txt +++ b/example/ck_tile/17_grouped_gemm/CMakeLists.txt @@ -1,3 +1,10 @@ add_executable(tile_example_grouped_gemm EXCLUDE_FROM_ALL grouped_gemm.cpp) add_executable(tile_example_quant_grouped_gemm EXCLUDE_FROM_ALL quant_grouped_gemm.cpp) add_executable(tile_example_grouped_gemm_preshuffle EXCLUDE_FROM_ALL grouped_gemm_preshuffle.cpp) +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() +target_compile_options(tile_example_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +target_compile_options(tile_example_quant_grouped_gemm PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 9975f2024b..606d98d9e2 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -356,6 +356,8 @@ int main(int argc, char* argv[]) #if CK_TILE_USE_WMMA return !run_grouped_gemm_example(argc, argv); #else - return !run_grouped_gemm_example(argc, argv); + return !run_grouped_gemm_example(argc, argv) || + !run_grouped_gemm_example(argc, argv) || + !run_grouped_gemm_example(argc, argv); #endif } diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 1ae0844032..6493a542ba 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -91,7 +91,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr 
bool Preshuffle = false; - static constexpr bool Persistent = false; + static constexpr bool Persistent = true; static constexpr bool DoubleSmemBuffer = false; }; @@ -139,6 +139,29 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr int kBlockPerCu = 2; }; +template +struct GemmConfigComputeV4_V2 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + + static constexpr int kBlockPerCu = 2; +}; + template struct GemmConfigPreshuffleDecode : public GemmConfigBase { diff --git a/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt b/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt new file mode 100644 index 0000000000..f382e0cf45 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/CMakeLists.txt @@ -0,0 +1 @@ +add_executable(tile_example_gemm_multi_abd_fp16 EXCLUDE_FROM_ALL gemm_multi_abd_fp16.cpp) diff --git a/example/ck_tile/22_gemm_multi_abd/README.md b/example/ck_tile/22_gemm_multi_abd/README.md new file mode 100644 index 0000000000..c272df3fb5 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/README.md @@ -0,0 +1,35 @@ +#Multiple ABD GEMM + +This folder contains example for Multiple ABD GEMM using ck_tile tile-programming implementation. 
+ +## build +``` +#in the root of ck_tile +mkdir build && cd build +#you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or \ + leave it blank +sh ../script/cmake-ck-dev.sh ../ <arch> +#The basic pipeline method on the gemm calculation +make tile_example_gemm_multi_abd_fp16 -j +``` +This will result in an executable `build/bin/tile_example_gemm_multi_abd_fp16` + +## example +``` +args: + -m M dimensions - (Default: 3840) + -n N dimensions - (Default: 4096) + -k K dimensions - (Default: 4096) +-as_layout Tensor A layout (default:R) +-bs_layout Tensor B layout (default:C) +-ds_layout Tensor D layout (default:R) +-e_layout Tensor E layout (default:R) +-stride_as Tensor A strides - (Default: 0) +-stride_bs Tensor B strides - (Default: 0) +-stride_e Tensor E strides - (Default: 0) +-stride_ds Tensor D strides - (Default: 0) + -v 0. No validation, 1. Validation on GPU. (Default: 1) + -warmup Number of iterations before benchmark the kernel. (Default: 50) + -repeat Number of iterations to benchmark the kernel. (Default: 100) + -kbatch kbatch for SplitK. (Default: 1) +``` \ No newline at end of file diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp new file mode 100644 index 0000000000..6d955c3a09 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#include + +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/host.hpp" +#include "gemm_multi_abd_fp16.hpp" +#include "utils.hpp" + +template +auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_config& s) -> float +{ + constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; + constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; + constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; + + constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; + constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; + constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; + + constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; + constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; + constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile; + + constexpr bool DoubleSmemBuffer = GemmConfig::DoubleSmemBuffer; + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = 
BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " + << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " + << blocks.y << ", " << blocks.z << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + + return ave_time; +} + +#include "run_gemm_multi_abd_fp16_example.inc" + +int main(int argc, char* argv[]) +{ +#if CK_TILE_USE_WMMA + return !run_multiple_abd_gemm_example(argc, argv); +#else + return !run_multiple_abd_gemm_example(argc, argv); +#endif +} diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp new file mode 100644 index 0000000000..35bc232eca --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.hpp @@ -0,0 +1,186 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 + +#ifndef CK_TILE_PIPELINE_DEFAULT +#define CK_TILE_PIPELINE_DEFAULT CK_TILE_PIPELINE_COMPUTE_V3 +#endif + +using A0DataType = ck_tile::half_t; +using A1DataType = ck_tile::half_t; + +using B0DataType = ck_tile::half_t; +using B1DataType = ck_tile::half_t; + +using D0DataType = ck_tile::half_t; +using D1DataType = ck_tile::half_t; + +using EDataType = ck_tile::half_t; + +using AsDataType = ck_tile::tuple; +using BsDataType = ck_tile::tuple; +using DsDataType = ck_tile::tuple; + +using AccDataType = float; + +struct GemmConfigMemory +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 8; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +struct GemmConfigV3 +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t 
N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV4 +{ + // Compute friendly for Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 32; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +struct GemmConfigV3_Wmma +{ + // Compute friendly for Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64; + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using 
UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "4096", "k dimension") + .insert("as_layout", "R", "As tensor data layout - Row by default") + .insert("bs_layout", "C", "Bs tensor data layout - Col by default") + .insert("ds_layout", "R", "Ds tensor data layout - Row by default") + .insert("e_layout", "R", "E tensor data layout - Row by default") + .insert("stride_as", "0", "Tensor A stride") + .insert("stride_bs", "0", "Tensor B stride") + .insert("stride_ds", "0", "Tensor Ds stride") + .insert("stride_e", "0", "Tensor E stride") + .insert("v", "1", "0. No validation, 1. 
Validation on CPU") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("kbatch", "1", "kbatch for SplitK"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} +using gemm_multi_abd_kargs = + ck_tile::GemmMultiABDHostArgs; + +template +float gemm_multi_abd(const gemm_multi_abd_kargs& kargs, const ck_tile::stream_config& s); diff --git a/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc b/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc new file mode 100644 index 0000000000..881961c9db --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/run_gemm_multi_abd_fp16_example.inc @@ -0,0 +1,311 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include + +template +float invoke_gemm_multi_abd(const std::array& as_m_k_dev_buf, + const std::array& bs_k_n_dev_buf, + const std::array& ds_m_n_dev_buf, + void* e_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + const std::array& StrideAs, + const std::array& StrideBs, + const std::array& StrideDs, + ck_tile::index_t StrideE, + int n_warmup, + int n_repeat, + int k_batch) +{ + gemm_multi_abd_kargs gemm_descs({as_m_k_dev_buf, + bs_k_n_dev_buf, + ds_m_n_dev_buf, + e_m_n_dev_buf, + k_batch, + M, + N, + K, + StrideAs, + StrideBs, + StrideDs, + StrideE}); + + float ave_time = gemm_multi_abd( + gemm_descs, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + + std::string op_name{"Gemm Multiple-ABD"}; + + std::size_t flop = 0, num_btype = 0; + + flop += std::size_t(2) * M * N * K; + + num_btype += + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Run Gemm
Multiple-ABD kernel with:\n"; + std::cout << "M =" << M << " N =" << N << " K =" << K << "\n"; + std::cout << "StrideA = " << StrideAs[0] << " StrideB = " << StrideBs[0] + << " StrideE = " << StrideE << "\n"; + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << "\n"; + + return ave_time; +} + +template +int run_gemm_multi_abd_example_with_layouts(int argc, + char* argv[], + const A0Layout a0_layout = A0Layout{}, + const A1Layout a1_layout = A1Layout{}, + const B0Layout b0_layout = B0Layout{}, + const B1Layout b1_layout = B1Layout{}, + const D0Layout d0_layout = D0Layout{}, + const D1Layout d1_layout = D1Layout{}, + const ELayout e_layout = ELayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + using AElementWiseFn = ck_tile::element_wise::AddScale; + using BElementWiseFn = ck_tile::element_wise::AddScale; + using CDEElementWiseFn = ck_tile::element_wise::MultiDMultiply; + using AsLayout = ck_tile::tuple; + using BsLayout = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t StrideA = arg_parser.get_int("stride_as"); + ck_tile::index_t StrideB = arg_parser.get_int("stride_bs"); + ck_tile::index_t StrideD = arg_parser.get_int("stride_ds"); + ck_tile::index_t StrideE = arg_parser.get_int("stride_e"); + + ck_tile::index_t StrideA0 = StrideA; + ck_tile::index_t StrideA1 = StrideA; + + ck_tile::index_t StrideB0 = StrideB; + ck_tile::index_t StrideB1 = StrideB; + + ck_tile::index_t StrideD0 = StrideD; + ck_tile::index_t StrideD1 = StrideD; + + const int n_warmup = arg_parser.get_int("warmup"); + const int n_repeat = arg_parser.get_int("repeat"); + const int k_batch = arg_parser.get_int("kbatch"); + + StrideA0 = get_default_stride(M, K, StrideA0, is_row_major(a0_layout)); // A tensors are M x K + StrideA1 = get_default_stride(M, K,
StrideA1, is_row_major(a1_layout)); + + StrideB0 = get_default_stride(K, N, StrideB0, is_row_major(b0_layout)); + StrideB1 = get_default_stride(K, N, StrideB1, is_row_major(b1_layout)); + + StrideD0 = get_default_stride(M, N, StrideD0, is_row_major(d0_layout)); + StrideD1 = get_default_stride(M, N, StrideD1, is_row_major(d1_layout)); + + StrideE = get_default_stride(M, N, StrideE, is_row_major(e_layout)); + + ck_tile::HostTensor a0_m_k_tesnor( + host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout))); + ck_tile::HostTensor a1_m_k_tesnor( + host_tensor_descriptor(M, K, StrideA1, is_row_major(a1_layout))); + + ck_tile::HostTensor b0_k_n_tensors( + host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout))); + ck_tile::HostTensor b1_k_n_tensors( + host_tensor_descriptor(K, N, StrideB1, is_row_major(b1_layout))); + + ck_tile::HostTensor d0_m_n_tensors( + host_tensor_descriptor(M, N, StrideD0, is_row_major(d0_layout))); + ck_tile::HostTensor d1_m_n_tensors( + host_tensor_descriptor(M, N, StrideD1, is_row_major(d1_layout))); + + ck_tile::HostTensor e_m_n_device_result( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a0_m_k_tesnor); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a1_m_k_tesnor); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(b0_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b1_k_n_tensors); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes()); + ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + 
ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data()); + a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data()); + + b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data()); + b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data()); + + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(), + a1_m_k_dev_buf.GetDeviceBuffer()}; + + std::array bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(), + b1_k_n_dev_buf.GetDeviceBuffer()}; + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array strideAs = {StrideA0, StrideA1}; + std::array strideBs = {StrideB0, StrideB1}; + std::array strideDs = {StrideD0, StrideD1}; + + invoke_gemm_multi_abd(as_ptr_buf, + bs_ptr_buf, + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + strideAs, + strideBs, + strideDs, + StrideE, + n_warmup, + n_repeat, + k_batch); + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + + ck_tile::HostTensor a_m_k_host_ref_element_result( + host_tensor_descriptor(M, K, StrideA0, is_row_major(a0_layout))); + ck_tile::HostTensor b_k_n_host_ref_element_result( + host_tensor_descriptor(K, N, StrideB0, is_row_major(b0_layout))); + ck_tile::HostTensor e_m_n_host_ref( + host_tensor_descriptor(M, N, StrideE, is_row_major(e_layout))); + a_m_k_host_ref_element_result.SetZero(); + b_k_n_host_ref_element_result.SetZero(); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_abd({a0_m_k_tesnor, a1_m_k_tesnor}, + {b0_k_n_tensors, b1_k_n_tensors}, + {d0_m_n_tensors, d1_m_n_tensors}, + a_m_k_host_ref_element_result, + b_k_n_host_ref_element_result, + e_m_n_host_ref); + + bool pass{true}; + 
if(arg_parser.get_int("v")) + { + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + + const auto rtol_atol = calculate_rtol_atol(K, 1, max_accumulated_value); + + pass &= ck_tile::check_err(e_m_n_device_result, + e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << std::endl; + std::cout << "Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is: " << (pass ? "correct" : "fail") << std::endl; + } + return pass; +} + +template +int run_multiple_abd_gemm_example(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + return -1; + } + + const std::string as_layout = arg_parser.get_str("as_layout"); + const std::string bs_layout = arg_parser.get_str("bs_layout"); + + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if(as_layout == "R" && bs_layout == "C") + { + return run_gemm_multi_abd_example_with_layouts( + argc, argv, Row{}, Row{}, Col{}, Col{}, Row{}, Row{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported data layout configuration for A,B and C tensors!"); + } +} diff --git a/example/ck_tile/22_gemm_multi_abd/utils.hpp b/example/ck_tile/22_gemm_multi_abd/utils.hpp new file mode 100644 index 0000000000..38bf8623d4 --- /dev/null +++ b/example/ck_tile/22_gemm_multi_abd/utils.hpp @@ -0,0 +1,38 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+ +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 8fce70ba04..75d32a5eb0 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -21,6 +21,7 @@ add_subdirectory(18_flatmm) add_subdirectory(19_gemm_multi_d) add_subdirectory(20_grouped_convolution) add_subdirectory(21_elementwise) +add_subdirectory(22_gemm_multi_abd) add_subdirectory(35_batched_transpose) add_subdirectory(38_block_scale_gemm) add_subdirectory(39_copy) diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 9c3967d99b..0c4f056a46 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -129,5 +129,7 @@ inline bool is_wmma_supported() return is_gfx103_supported() || is_gfx11_supported() || is_gfx12_supported(); } +inline bool is_tf32_supported() { return (ck::get_device_name() == "gfx942") ? 
true : false; } + } // namespace ck #endif diff --git a/include/ck/library/utility/check_err.hpp b/include/ck/library/utility/check_err.hpp index d33ecaeef8..185166f7ec 100644 --- a/include/ck/library/utility/check_err.hpp +++ b/include/ck/library/utility/check_err.hpp @@ -180,13 +180,13 @@ check_err(const Range& out, if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r)) { max_err = err > max_err ? err : max_err; - err_count++; if(err_count < 5) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; } res = false; + err_count++; } } if(!res) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp index e848ca35b5..55015dd30f 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp @@ -49,6 +49,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using ThisThreadBlock = ThisThreadBlock; + using ElementDataTypeA = + conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + conditional_t, float, ComputeTypeB>; + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); static constexpr index_t KPerBlock = @@ -64,7 +69,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static constexpr index_t WaveSize = BlockSize / MWaves / NWaves; static constexpr auto xdlops_gemm = - XdlopsGemm{}; + XdlopsGemm{}; static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops; @@ -172,6 +177,11 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, "wrong!"); + if constexpr(is_same_v || is_same_v) + { + static_assert(is_same_v, + "ComputeTypeA and ComputeTypeB must be same when one 
of them is tf32"); + } } __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2() @@ -297,9 +307,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, MRepeat, 1>{}([&](auto m0) { @@ -321,20 +331,20 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 b_thread_buf); static_for<0, KPerThread, KPack>{}([&](auto k) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = a_thread_buf + a_thread_vec.template AsType()(i) = a_thread_buf [Number{}]; - b_thread_vec.template AsType()(i) = b_thread_buf + b_thread_vec.template AsType()(i) = b_thread_buf [Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -361,7 +371,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, Number{}, xdlops_gemm.GetRegSizePerXdlops())); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -371,7 +381,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -445,6 +455,11 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 using Base::KPerThread; using Base::xdlops_gemm; + using ElementDataTypeA = + conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + conditional_t, float, ComputeTypeB>; + static constexpr index_t KPerInnerLoop = 
math::max(KPerThread / NumMacClusters, KPack); // 2-wave optimized blockwise gemm @@ -453,9 +468,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 const BBlockBuffer& b_block_buf, CThreadBuffer& c_thread_buf) const { - auto a_thread_buf = make_static_buffer( + auto a_thread_buf = make_static_buffer( a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( + auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) { @@ -499,22 +514,22 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) { static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - vector_type a_thread_vec; - vector_type b_thread_vec; + vector_type a_thread_vec; + vector_type b_thread_vec; static_for<0, KPack, 1>{}([&](auto i) { - a_thread_vec.template AsType()(i) = + a_thread_vec.template AsType()(i) = a_thread_buf[Number{}]; - b_thread_vec.template AsType()(i) = + b_thread_vec.template AsType()(i) = b_thread_buf[Number{}]; }); using mfma_input_type_a = - typename vector_type::type; + typename vector_type::type; using mfma_input_type_b = - typename vector_type::type; + typename vector_type::type; constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); @@ -563,7 +578,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 make_tuple(Number{}, I1, I1, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -573,7 +588,7 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, @@ -622,19 +637,21 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector() } else if constexpr(LoopSched == LoopScheduler::Interwave) { - return BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + return 
BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1< + BlockSize, + FloatA, + FloatB, + FloatAcc, + AK0MK1BlockDesc, + BK0NK1BlockDesc, + MPerXDL, + NPerXDL, + MRepeat, + NRepeat, + KPack, + ComputeTypeA, + ComputeTypeB, + CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>{}; } }; diff --git a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp index 7296e4faaa..18223c78f7 100644 --- a/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp +++ b/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp @@ -11,6 +11,8 @@ namespace ck { namespace tensor_operation { namespace device { +#define DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 1 + template #include +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -853,7 +854,10 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle arg.e_grid_desc_m_n_, arg.block_2_ctile_map_)) { - printf("GridwiseOp: Validity check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Validity check failure\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 55aa7b59ee..72191632d8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include @@ -171,8 +172,8 @@ struct DeviceGemm_Wmma_CShuffleV3_Common // other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot // be odd. 
constexpr bool AtomicsImplementationExists = - !(std::is_same_v || - std::is_same_v) || + !(std::is_same_v || std::is_same_v || + std::is_same_v) || (CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0); if(has_main_k_block_loop) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp index 8daaafaed1..23b0faec67 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_lds_direct_load.hpp @@ -119,7 +119,9 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm; + PipelineVer, + ComputeDataType>; + using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; @@ -214,6 +216,14 @@ struct DeviceGemm_Xdl_CShuffle_LdsDirectLoad : public DeviceGemm) + { + if(!is_tf32_supported()) + { + return false; + } + } + // Check vector load/store. 
{ using Row = ck::tensor_layout::gemm::RowMajor; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp index 5d68ca720a..be94da1e50 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp @@ -144,18 +144,39 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl end(e_g_k_c_xs_lengths), begin(filter_spatial_lengths_)); - if(split_k < 0) + if constexpr(IsTwoStageNeeded) { - const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = - DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); - const index_t grid_size = gdx * gdy * gdz; - split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + if(split_k < 0) + { + const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); + const index_t grid_size = gdx * gdy * gdz; + split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + } + else + { + split_k_ = split_k; + } } else { - split_k_ = split_k; +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(split_k < 0) + { + const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); + const index_t grid_size = gdx * gdy * gdz; + split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); + } + else +#endif + { + split_k_ = split_k; + } } if constexpr(IsTwoStageNeeded) @@ -318,6 +339,16 @@ struct DeviceGroupedConvBwdWeight_Explicit_Xdl static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS 
+ if constexpr(!IsTwoStageNeeded) + { + if(arg.split_k_ < 0) + { + return false; + } + } +#endif + if constexpr(NDimSpatial == 2) { if constexpr(!is_NHWGC_GKYXC_NHWGK()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp index 934dc7ee8e..987a1e273a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_xdl_cshuffle.hpp @@ -671,6 +671,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -683,6 +684,7 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle grid_size); } else +#endif { k_batch_ = split_k; } @@ -939,6 +941,12 @@ struct DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp index b361409e38..22fc13bae4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp @@ -553,6 +553,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle conv_ngchw_to_nhwgc_transformer.TransposeWeiStrides(e_g_k_c_xs_lengths, e_g_k_c_xs_strides); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN; @@ -565,6 
+566,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle grid_size); } else +#endif { k_batch_ = split_k; } @@ -934,6 +936,12 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif if(!ck::is_xdl_wmma_supported()) { return false; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index 8bf188be2e..735eebbdf6 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -524,6 +524,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 end(a_g_n_k_wos_lengths), begin(output_spatial_lengths_)); +#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS if(split_k < 0) { ck::index_t gemmM, gemmN, gemmK; @@ -549,6 +550,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 } } else +#endif { k_batch_ = split_k; } @@ -1275,6 +1277,13 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 static bool IsSupportedArgument(const Argument& arg) { +#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS + if(arg.k_batch_ < 0) + { + return false; + } +#endif + const index_t GemmM = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const index_t GemmN = arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1); const index_t GemmK = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) * diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 1412c960c7..cc8561a09f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ 
b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1003,11 +1003,20 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle void Print() const { + std::cout << "AComputeDataType: " << get_type_name() + << "; BComputeDataType: " << get_type_name() + << "; EDataType: " << get_type_name() << std::endl; + std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl; std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl; static_for<0, NumDTensor, 1>{}( [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; }); std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl; + + std::cout << "a grid desc" << a_grid_desc_ak0_m_ak1_ << std::endl; + std::cout << "b grid desc" << b_grid_desc_bk0_n_bk1_ << std::endl; + std::cout << "e grid desc" << e_grid_desc_mblock_mperblock_nblock_nperblock_ + << std::endl; } // private: @@ -1198,7 +1207,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle isMultiA, isMultiB, CTranspose>; - return launch_and_time_kernel( stream_config, kernel, @@ -1281,7 +1289,6 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float avg_time = 0.f; - if constexpr(NeedTransposeKernel) { const index_t a_grid_size = @@ -1686,7 +1693,23 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle { return false; } - + if constexpr(is_same_v || + is_same_v) + { + if(!is_tf32_supported()) + { + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } // check Gridwise GEMM if(get_warp_size() == 64) { @@ -1766,6 +1789,28 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle } } + if constexpr(is_same_v || + is_same_v) + + { + if(!(ck::get_device_name() == "gfx942")) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "TF32 is 
enabled on gfx942 only" << std::endl; + } + return false; + } + if constexpr(!is_same_v) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ComputeDataType for A and B should be same while using TF32" + << std::endl; + } + return false; + } + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index b61c7a09eb..fa7eb4faaa 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -398,41 +398,54 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) { - print("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", - M, - N, - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + } return false; } if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) { - print("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " - "%d, %d, %d, %d\n", - M, - L, - K, - N, - MPerBlock, - LPerBlock, - KPerBlock, - NPerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | " + "M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + } return false; } // check gemm1 gridwise gemm pipeline if(!(LPerBlock % LTilePerBlock == 0)) { - print("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", - LPerBlock, - LTilePerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: inner loop 
division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + } return false; } if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n)) { - print("GridwiseOp: invalid block_2_ctile_map\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + print("GridwiseOp: invalid block_2_ctile_map\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp index 1754e07e6a..502c449ef1 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -569,26 +570,33 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1))) { - printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", - M, - N, - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: M/N Length err, A_M/N = %d, %d | C_M/N = %d, %d\n", + M, + N, + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + } return false; } if(!(M % MPerBlock == 0 && L % LPerBlock == 0 && K % KPerBlock == 0 && N % NPerBlock == 0)) { - printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | M/L/K/NPerBlock = " - "%d, %d, %d, %d\n", - M, - L, - K, - N, - MPerBlock, - LPerBlock, - KPerBlock, - NPerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + 
printf("GridwiseOp: M/L/K/N Division err, M/L/K/N = %d, %d, %d, %d | " + "M/L/K/NPerBlock = " + "%d, %d, %d, %d\n", + M, + L, + K, + N, + MPerBlock, + LPerBlock, + KPerBlock, + NPerBlock); + } return false; } @@ -596,23 +604,32 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma const auto num_gemm0_k_loop = K / KPerBlock; if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop)) { - printf("GridwiseOp: outer loop unsupport\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: outer loop unsupport\n"); + } return false; } // check gemm1 gridwise gemm pipeline if(!(LPerBlock % LTilePerBlock == 0)) { - printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", - LPerBlock, - LTilePerBlock); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: inner loop division, L/LTilePerblock: %d, %d\n", + LPerBlock, + LTilePerBlock); + } return false; } const auto num_gemm1_k_inner_loop = LPerBlock / LTilePerBlock; if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop)) { - printf("GridwiseOp: inner loop unsupport\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: inner loop unsupport\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp index 8011fa56d3..c8b154228f 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -466,20 +467,26 @@ struct GridwiseFpAintBGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", - GetAProblemsizeMK()[I0], - GetAProblemsizeMK()[I1], - GetBProblemsizeNK()[I0], - GetBProblemsizeNK()[I1], - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); - printf("GridwiseOp err: ProblemSize check"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp err: ProblemSize division"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: ProblemSize division"); + } return false; } @@ -488,7 +495,10 @@ struct GridwiseFpAintBGemm_Wmma if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { - printf("GridwiseOp err: Pipeline not support this k_loop"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp index c198711dbb..cbad6a5673 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp @@ -708,7 +708,9 @@ struct 
GridwiseGemmMultipleABD_xdl_cshuffle MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeDataType, + BComputeDataType>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index 46979a5620..7d68d64ed8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -653,13 +654,19 @@ struct GridwiseGemmMultipleD_Wmma if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + } return false; } @@ -747,20 +754,29 @@ struct GridwiseGemmMultipleD_Wmma if(!valid) { - printf("GridwiseOp: D descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: D descriptor dimension check failure\n"); + } return false; } if(!(M == e_grid_desc_m_n.GetLength(I0) && N == 
e_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: ABE descriptor dimension cross check failure\n"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp: Problemsize descriptor dimension check failure\n"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp index 59d7f357ec..a97e4503a8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp @@ -107,8 +107,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle using BComputeDataType = conditional_t, ck::bhalf_t, BComputeDataType_>; #else - using AComputeDataType = AComputeDataType_; - using BComputeDataType = BComputeDataType_; + using AComputeDataType = + conditional_t, float, AComputeDataType_>; + using BComputeDataType = + conditional_t, float, BComputeDataType_>; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -659,26 +661,27 @@ struct GridwiseGemmMultipleD_xdl_cshuffle : false; constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); - - auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< - BlockSize, - AComputeDataType, - BComputeDataType, - AccDataType, - decltype(a_block_desc_ak0_m_ak1), - decltype(b_block_desc_bk0_n_bk1), - MPerXdl, - NPerXdl, - MXdlPerWave, - NXdlPerWave, - KPack, - LoopSched>(); + auto blockwise_gemm = 
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector< + BlockSize, + AComputeDataType, + BComputeDataType, + AccDataType, + decltype(a_block_desc_ak0_m_ak1), + decltype(b_block_desc_bk0_n_bk1), + MPerXdl, + NPerXdl, + MXdlPerWave, + NXdlPerWave, + KPack, + LoopSched, + AComputeDataType_, + BComputeDataType_>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp index 095b1c5d63..1e72e78349 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp @@ -144,7 +144,7 @@ template + typename BComputeDataType_ = AComputeDataType_> struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad { static constexpr index_t NumDTensor = DsDataType::Size(); @@ -172,7 +172,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad using AComputeDataType = conditional_t, ck::bhalf_t, AComputeDataType_>; #else - using AComputeDataType = AComputeDataType_; + using AComputeDataType = + conditional_t, float, AComputeDataType_>; + using BComputeDataType = + conditional_t, float, BComputeDataType_>; #endif __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -573,7 +576,6 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad // This forces m/n_block_data_idx_on_grid into SGPR. 
const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); - const index_t n_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock); @@ -640,10 +642,10 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad constexpr auto is_scale_mfma = false; constexpr index_t KPack = math::max(lcm_AK1_BK1, - MfmaSelector::selected_mfma.k_per_blk); @@ -659,7 +661,9 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad MXdlPerWave, NXdlPerWave, KPack, - LoopSched>(); + LoopSched, + AComputeDataType_, + BComputeDataType_>(); auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp index 4a15958adb..65f74de3cf 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp @@ -1,8 +1,9 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -458,20 +459,26 @@ struct GridwiseGemm_Wmma if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) && K == GetBProblemsizeNK()[I1])) { - printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", - GetAProblemsizeMK()[I0], - GetAProblemsizeMK()[I1], - GetBProblemsizeNK()[I0], - GetBProblemsizeNK()[I1], - c_grid_desc_m_n.GetLength(I0), - c_grid_desc_m_n.GetLength(I1)); - printf("GridwiseOp err: ProblemSize check"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("A: MxK = %d x %d, B: NxK = %d x %d, C: MxN = %d x %d\n", + GetAProblemsizeMK()[I0], + GetAProblemsizeMK()[I1], + GetBProblemsizeNK()[I0], + GetBProblemsizeNK()[I1], + c_grid_desc_m_n.GetLength(I0), + c_grid_desc_m_n.GetLength(I1)); + printf("GridwiseOp err: ProblemSize check"); + } return false; } if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0)) { - printf("GridwiseOp err: ProblemSize division"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: ProblemSize division"); + } return false; } @@ -480,7 +487,10 @@ struct GridwiseGemm_Wmma if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { - printf("GridwiseOp err: Pipeline not support this k_loop"); + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + printf("GridwiseOp err: Pipeline not support this k_loop"); + } return false; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index b226730a09..59d3a6a4c5 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -1065,6 +1065,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + if 
constexpr(is_same, int8_t>::value) + { + if(karg.KBatch > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "int8_t does not support KBatch > 1. KBatch: " << karg.KBatch + << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + return false; + } + } + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return true; } diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index deea6ae9cc..a97d9589cf 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -41,11 +41,11 @@ static constexpr bool scale_mfma_hw_support() enum struct MfmaInstr { - mfma_f32_32x32x1xf32 = 0, - mfma_f32_16x16x1xf32, - mfma_f32_4x4x1xf32, - mfma_f32_32x32x2xf32, - mfma_f32_16x16x4xf32, + mfma_f32_32x32x1f32 = 0, + mfma_f32_16x16x1f32, + mfma_f32_4x4x1f32, + mfma_f32_32x32x2f32, + mfma_f32_16x16x4f32, mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, @@ -78,6 +78,8 @@ enum struct MfmaInstr mfma_f32_16x16x128f8f6f4, mfma_scale_f32_32x32x64f8f6f4, mfma_scale_f32_16x16x128f8f6f4, + mfma_f32_16x16x8xf32, // tf32 + mfma_f32_32x32x4xf32, // gfx11 wmma_f32_16x16x16_f16, wmma_f32_16x16x16_bf16, @@ -98,7 +100,7 @@ template struct mfma_type; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -120,7 +122,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 4; @@ -142,7 +144,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -164,7 +166,7 @@ struct mfma_type }; template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static 
constexpr index_t num_groups_per_blk = 1; @@ -187,7 +189,7 @@ struct mfma_type // treat 4x4x1 as a single-blk 4x64 mfma template <> -struct mfma_type +struct mfma_type { static constexpr index_t group_size = 4; static constexpr index_t num_groups_per_blk = 1; @@ -947,6 +949,70 @@ struct mfma_type } }; +/** + * num_threads_per_blk == n_per_blk + * num_regs_per_blk * num_input_blks == m_per_blk + * num_regs_per_blk * wave_size == m_per_blk * n_per_blk + * + * group_size * num_groups_per_blk == num_regs_per_blk + * + * num_regs_per_blk is output(CD) register size which is determined by the instruction. + * k_per_blk(K1PerXdlops) is input(AB) register size which is determined by the instruction. + * group_size is corresponding to CD rows mapping. see: GetBeginOfThreadBlk() + * + * is_k_reduction = (k_per_blk == KPerXdlops) ? false: true. + * + * if (is_k_reduction){ + * num_output_blks == 1; + * } else { + * num_input_blks == num_output_blks; + * } + */ +template <> +struct mfma_type +{ + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t m_per_blk = 16; // from the instruction + static constexpr index_t n_per_blk = 16; // from the instruction + static constexpr index_t num_threads_per_blk = n_per_blk; // 16 + static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 4 + static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 4 + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2. 
+ static constexpr bool is_k_reduction = true; + + // AB register size : 2, register size: 4 + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_16x16x8xf32::Run(a, b, reg_c); + } +}; + +template <> +struct mfma_type +{ + static constexpr index_t wave_size = 64; // fixed + static constexpr index_t m_per_blk = 32; // from the instruction + static constexpr index_t n_per_blk = 32; // from the instruction + static constexpr index_t num_threads_per_blk = n_per_blk; // 32 + static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16 + static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2 + static constexpr index_t group_size = 4; // corresponding to CD rows mapping + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; + // AB register size: 2, CD register size: 16 + template + __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const + { + intrin_mfma_f32_32x32x4xf32::Run(a, b, reg_c); + } +}; + // gfx11 struct mfma_type_gfx11_base { @@ -1116,6 +1182,20 @@ struct mfma_type : public mfma_type_gfx12 } }; +/** + * @class MfmaSelector + * @brief Selects the appropriate MFMA instruction type and configuration for given data types + * and tile sizes on AMD GPUs. + * + * @tparam base_type The base data type for the matrix operation (e.g., float, half_t). + * @tparam MPerXdlops The number of rows per XDLops tile. + * @tparam NPerXdlops The number of columns per XDLops tile. + * @tparam additional_type (Optional) Additional data type for mixed-precision or special cases. + * Defaults to base_type. + * @tparam is_single_rate_mfma (Optional) Whether to use single-rate MFMA instructions. + * Defaults to false. + * @tparam is_scale_mfma (Optional) Whether to use scale MFMA instructions. Defaults to false. 
+ */ template constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x1xf32; + return MfmaInstr::mfma_f32_32x32x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x1xf32; + return MfmaInstr::mfma_f32_32x32x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_16x16x1xf32; + return MfmaInstr::mfma_f32_16x16x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_4x4x1xf32; + return MfmaInstr::mfma_f32_4x4x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_4x4x1xf32; + return MfmaInstr::mfma_f32_4x4x1f32; } template <> constexpr auto GetMfma() { - return MfmaInstr::mfma_f32_32x32x2xf32; + return MfmaInstr::mfma_f32_32x32x2f32; } template <> @@ -1188,10 +1268,22 @@ struct MfmaSelector #elif defined(__gfx11__) return MfmaInstr::wmma_unsupport_16x16_gfx11; #else - return MfmaInstr::mfma_f32_16x16x4xf32; + return MfmaInstr::mfma_f32_16x16x4f32; #endif } + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4xf32; + } + + template <> + constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x8xf32; + } + template <> constexpr auto GetMfma() { @@ -1896,7 +1988,7 @@ struct XdlopsGemm __device__ __host__ static constexpr index_t GetRegSizePerXdlops() { - return MPerXdlops * NPerXdlops / mfma_instr.wave_size; + return mfma_instr.num_regs_per_blk; } __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; } @@ -1906,12 +1998,12 @@ struct XdlopsGemm { static_assert( is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || is_same::value || - is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || + is_same::value || is_same::value || (is_same::value && is_same::value) || (is_same::value && is_same::value), - "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!"); + "base_type must be double, float, tf32_t, half, bfloat16, 
int8_t, f8_t or bf8_t!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { if constexpr(!TransposeC) diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index 2edbb7c789..0b73f76155 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -33,8 +33,34 @@ namespace ck { -using f8_fnuz_t = _BitInt(8); -using bf8_fnuz_t = unsigned _BitInt(8); +struct f8_fnuz_t +{ + using data_type = unsigned char; + data_type m_data; + __host__ __device__ explicit constexpr f8_fnuz_t(data_type in_data) : m_data(in_data) {} + __host__ __device__ explicit constexpr f8_fnuz_t() = default; + __host__ __device__ bool constexpr operator==(f8_fnuz_t other) const + { + return m_data == other.m_data; + } + __host__ __device__ explicit constexpr operator data_type() const { return m_data; } +}; + +struct bf8_fnuz_t +{ + using data_type = unsigned char; + data_type m_data; + __host__ __device__ explicit constexpr bf8_fnuz_t(data_type in_data) : m_data(in_data) {} + __host__ __device__ explicit constexpr bf8_fnuz_t() = default; + __host__ __device__ bool constexpr operator==(bf8_fnuz_t other) const + { + return m_data == other.m_data; + } + __host__ __device__ explicit constexpr operator data_type() const { return m_data; } +}; + +static_assert(1 == sizeof(f8_fnuz_t)); +static_assert(1 == sizeof(bf8_fnuz_t)); typedef unsigned char fp8_storage_t; diff --git a/include/ck/utility/amd_xdlops.hpp b/include/ck/utility/amd_xdlops.hpp index 02a7a72b8c..be3a5cea42 100644 --- a/include/ck/utility/amd_xdlops.hpp +++ b/include/ck/utility/amd_xdlops.hpp @@ -1636,4 +1636,45 @@ struct intrin_mfma_f32_16x16x32bf8f8<16, 16> } }; +/******************* tf32 *************************************/ +template +struct intrin_mfma_f32_16x16x8xf32; + +template <> +struct intrin_mfma_f32_16x16x8xf32<16, 16> +{ + template + __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + 
reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + +template +struct intrin_mfma_f32_32x32x4xf32; + +template <> +struct intrin_mfma_f32_32x32x4xf32<32, 32> +{ + template + __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c) + { +#if defined(__gfx94__) + reg_c.template AsType()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); +#else + ignore = reg_a; + ignore = reg_b; + ignore = reg_c; +#endif + } +}; + } // namespace ck diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 5fbe30d21b..984bb4d862 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -26,6 +26,7 @@ using byte = unsigned char; using std::byte; #endif +using tf32_t = _BitInt(19); // 1 sign bit, 8 exponent bits, 10 mantissa bits using bhalf_t = ushort; using half_t = _Float16; using int4_t = _BitInt(4); @@ -204,7 +205,7 @@ inline constexpr bool is_native_type() return is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || is_same::value || - is_same::value || is_same::value || is_same::value; + is_same_v || is_same_v || is_same::value; } // scalar_type @@ -299,14 +300,14 @@ struct scalar_type template <> struct scalar_type { - using type = f8_fnuz_t; + using type = f8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; template <> struct scalar_type { - using type = bf8_fnuz_t; + using type = bf8_fnuz_t::data_type; static constexpr index_t vector_size = 1; }; @@ -461,4 +462,38 @@ using int64_t = long long; using int64_t = long; #endif +template +inline const char* get_type_name() +{ + if constexpr(is_same_v) + return "fp16"; + else if constexpr(is_same_v) + return "bf16"; + else if 
constexpr(is_same_v) + return "tf32"; + else if constexpr(is_same_v) + return "int4"; + else if constexpr(is_same_v) + return "f4"; + else if constexpr(is_same_v) + return "f6"; + else if constexpr(is_same_v) + return "bf6"; + else if constexpr(is_same_v) + return "f8"; + else if constexpr(is_same_v) + return "bf8"; + else if constexpr(is_same_v) + return "e8m0"; + else if constexpr(is_same_v) + return "fp32"; +#if defined(__HIPCC_RTC__) || defined(CK_CODE_GEN_RTC) + else + return "unknown"; +#else + else + return typeid(T).name(); +#endif +} + } // namespace ck diff --git a/include/ck/utility/dtype_vector.hpp b/include/ck/utility/dtype_vector.hpp index ae0edb35ee..27a7545a0e 100644 --- a/include/ck/utility/dtype_vector.hpp +++ b/include/ck/utility/dtype_vector.hpp @@ -1294,6 +1294,18 @@ struct nnvb_data_t_selector using type = bf8_ocp_t::data_type; }; +template <> +struct nnvb_data_t_selector +{ + using type = f8_fnuz_t::data_type; +}; + +template <> +struct nnvb_data_t_selector +{ + using type = bf8_fnuz_t::data_type; +}; + template <> struct nnvb_data_t_selector { diff --git a/include/ck/utility/f8_utils.hpp b/include/ck/utility/f8_utils.hpp index 799683ae65..748aa07f9e 100644 --- a/include/ck/utility/f8_utils.hpp +++ b/include/ck/utility/f8_utils.hpp @@ -39,7 +39,7 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) int exponent, bias; uint32_t head, mantissa, sign; // nan code is same for float and half - constexpr Y nan_code = 0x80; + constexpr uint8_t nan_code = 0x80; constexpr uint32_t nan_mask = NumericUtils::nan_mask; // convert to bitwise @@ -60,17 +60,17 @@ __host__ __device__ Y run_cast_to_f8(X x, uint32_t rng) if constexpr(negative_zero_nan) { if((x_bitwise & nan_mask) == nan_mask) - return nan_code; + return Y{nan_code}; } else { if((x_bitwise & nan_mask) == nan_mask) - return signed_inf + (mantissa != 0 ? 1 : 0); + return Y{static_cast(signed_inf + (mantissa != 0 ? 
1 : 0))}; } // check if x is 0.0 if(x_bitwise == 0) - return 0; + return Y{0}; // First need to check if it is normal or denorm as there is a difference of implict 1 // Then need to adjust the exponent to align with the F8 exponent, in the meanwhile, shift @@ -178,9 +178,10 @@ In this case, the fp16 mantissa should be shift left by 1 */ // check if x is 0.0 or -0.0 if(out_exponent == 0 && mantissa == 0) - return negative_zero_nan ? 0 : (sign << (out_exp + out_mant)); + return Y{negative_zero_nan ? 0 : static_cast(sign << (out_exp + out_mant))}; mantissa &= (1 << out_mant) - 1; - return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa; + return Y{static_cast((sign << (out_exp + out_mant)) | (out_exponent << out_mant) | + mantissa)}; } template @@ -195,8 +196,8 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr int out_mant = NumericUtils::mant; // prepare the codes - constexpr X nan_code = 0x80; - using T_bitwise = typename NumericUtils::bitwise_type; + constexpr uint8_t nan_code = 0x80; + using T_bitwise = typename NumericUtils::bitwise_type; constexpr T_bitwise Inf_bitwise = NumericUtils::Inf; constexpr T_bitwise NegInf_bitwise = NumericUtils::NegInf; @@ -209,13 +210,13 @@ __host__ __device__ Y run_cast_from_f8(X x) constexpr Y Neg0 = bit_cast(Neg0_bitwise); // check if x is 0.0 - if(x == 0) + if(!static_cast(x)) return static_cast(0); // unpack the input - uint32_t sign = x >> (in_exp + in_mant); - uint32_t mantissa = x & ((1 << in_mant) - 1); - int exponent = (x & 0x7F) >> in_mant; + uint32_t sign = static_cast(x) >> (in_exp + in_mant); + uint32_t mantissa = static_cast(x) & ((1 << in_mant) - 1); + int exponent = (static_cast(x) & 0x7F) >> in_mant; constexpr int exp_low_cutoff = (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 
1 : 0); @@ -223,12 +224,12 @@ __host__ __device__ Y run_cast_from_f8(X x) if constexpr(negative_zero_nan) { - if(x == nan_code) + if(static_cast(x) == nan_code) return NaN; } else { - if(x == nan_code) + if(static_cast(x) == nan_code) return Neg0; if(exponent == ((1 << in_exp) - 1)) return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN; diff --git a/include/ck/utility/random_gen.hpp b/include/ck/utility/random_gen.hpp index c37d3922ca..2ff46457fc 100644 --- a/include/ck/utility/random_gen.hpp +++ b/include/ck/utility/random_gen.hpp @@ -3,6 +3,7 @@ #pragma once #include +#include #include "ck/ck.hpp" #ifdef CK_CODE_GEN_RTC @@ -17,7 +18,7 @@ namespace ck { template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { - uint32_t x = *(reinterpret_cast(&val)); + uint32_t x = bit_cast(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits ^= x >> 16; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); @@ -33,7 +34,7 @@ __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = template {}, bool> = false> __host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t) { - uint16_t x = *(reinterpret_cast(&val)); + uint16_t x = bit_cast(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); drop_bits *= 0x7000149; diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 8e53728ef6..913557fc7a 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -187,6 +187,19 @@ inline __host__ __device__ constexpr bf8_ocp_t type_convert(int return bf8_ocp_t{type_convert(x)}; } +template , bool> = false> +inline __host__ __device__ constexpr float type_convert(float x) +{ + union + { + float fp32; + uint32_t int32; + } u = {x}; + + u.int32 = u.int32 & 0xffffe000; + return u.fp32; +} + // Convert X to Y template __host__ __device__ constexpr Y 
type_convert_sp(X x) @@ -338,7 +351,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0); // 0 pos val.i32val = ival; - return val.i8val[0]; // little endian + return f8_t{val.i8val[0]}; // little endian #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -406,7 +419,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos val.i32val = ival; - return val.i8val[0]; // little endian + return bf8_t{val.i8val[0]}; // little endian #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -642,7 +655,7 @@ inline __host__ __device__ f8_fnuz_t f8_convert_rne(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_fp8, -max_fp8); ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; - return val.i8val[0]; + return f8_t{val.i8val[0]}; #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -694,7 +707,7 @@ inline __host__ __device__ bf8_fnuz_t f8_convert_rne(float x) val.fval = __builtin_amdgcn_fmed3f(val.fval, max_bf8, -max_bf8); ival = __builtin_amdgcn_cvt_pk_bf8_f32(val.fval, val.fval, ival, false); // false -> WORD0 val.i32val = ival; - return val.i8val[0]; + return bf8_t{val.i8val[0]}; #else constexpr bool negative_zero_nan = true; constexpr bool clip = true; @@ -911,7 +924,7 @@ inline __host__ __device__ float type_convert(f8_fnuz_t x) { #if defined(__gfx94__) float fval; - uint32_t i32val = static_cast(x); + uint32_t i32val = static_cast(static_cast(x)); fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0); // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return fval; @@ -1417,7 +1430,7 @@ inline __host__ __device__ float 
type_convert(bf8_fnuz_t x) { #if defined(__gfx94__) float fval; - uint32_t i32val = static_cast(x); + uint32_t i32val = static_cast(static_cast(x)); fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0); // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val)); return fval; diff --git a/include/ck_tile/core/numeric/pk_fp4.hpp b/include/ck_tile/core/numeric/pk_fp4.hpp index f25b98f5a0..8b78990d08 100644 --- a/include/ck_tile/core/numeric/pk_fp4.hpp +++ b/include/ck_tile/core/numeric/pk_fp4.hpp @@ -23,7 +23,8 @@ using fp32x2_t = float __attribute__((ext_vector_type(2))); using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); -CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f); +struct pk_float4_e2m1_t; +CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t float_to_pk_fp4(const float& x, float scale = 1.f); // TODO: Add stochastic method struct pk_float4_e2m1_t @@ -31,7 +32,7 @@ struct pk_float4_e2m1_t // TODO: Can we merge raw_type and type? 
using raw_type = uint8_t; using type = raw_type; - raw_type data; + type data; CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t() : data{type{}} {} template >> @@ -39,12 +40,12 @@ struct pk_float4_e2m1_t { } CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init, float scale = 1.f) - : data{float_to_e2m1(init, scale)} + : data{float_to_pk_fp4(init, scale)} { } CK_TILE_HOST_DEVICE constexpr operator type() const { return data; } - CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; } - CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; } + CK_TILE_HOST_DEVICE constexpr type& get() { return data; } + CK_TILE_HOST_DEVICE constexpr type get() const { return data; } CK_TILE_HOST_DEVICE constexpr float to_float(float scale = 1.f) const; CK_TILE_HOST_DEVICE constexpr fp32x2_t to_fp32x2(float scale = 1.f) const; @@ -61,8 +62,19 @@ struct pk_float4_e2m1_t CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const { return to_bf16x2(); } template - CK_TILE_HOST_DEVICE constexpr raw_type unpack(number) const; - CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1) + CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t unpack(number) const + { + return _unpack(number{}); + } + CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const pk_float4_e2m1_t& x0, + const pk_float4_e2m1_t& x1) + { + return _pack(x0.get(), x1.get()); + } + + template + CK_TILE_HOST_DEVICE constexpr type _unpack(number) const; + CK_TILE_HOST_DEVICE constexpr static type _pack(const type x0, const type x1) { return (x1 << 4) | (x0 & 0b00001111); } @@ -92,7 +104,7 @@ struct pk_float4_e2m1_t }; using pk_fp4_t = pk_float4_e2m1_t; -using pk_fp4_raw_t = typename pk_fp4_t::raw_type; +using pk_fp4_raw_t = typename pk_fp4_t::type; template <> struct numeric_traits @@ -124,7 +136,7 @@ struct numeric CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { 
return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; } - CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; } + CK_TILE_HOST_DEVICE static constexpr pk_fp4_t denorm_min() { return binary_min_subnorm; } CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; } // N/A @@ -136,7 +148,7 @@ struct numeric }; template -CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number) const +CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::_unpack(number) const { static_assert(I < 2, "Index is out of range."); if constexpr(I == 1) @@ -202,7 +214,7 @@ CK_TILE_HOST_DEVICE constexpr bf16_t pk_fp4_t::to_bf16(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return bf16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; + return bf16_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; #endif } @@ -211,13 +223,13 @@ CK_TILE_HOST_DEVICE constexpr bf16x2_t pk_fp4_t::to_bf16x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return bf16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), - type_convert(convert_to_float(unpack(number<1>{}), scale))}; + return bf16x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } -// TODO: make float_to_e2m1 generic so that we can convert from directrly. -CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale) +// TODO: make it generic so that we can convert from directrly. 
+CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_mxfp4(float x, float scale) { #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); @@ -227,14 +239,20 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t float_to_e2m1(float x, float scale) } CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x, float scale) { - return float_to_e2m1(x, scale); +#if CK_TILE_FP4_CVT_DEVICE + return impl::_to_f4(x, scale); +#else + auto res = convert_to_type(x, scale); + return pk_fp4_t::_pack(res, res); +#endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16_to_pk_fp4(const fp16_t& x, float scale) { #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x), scale); + auto res = float_to_mxfp4(type_convert(x), scale); + return pk_fp4_t::_pack(res, res); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float scale) @@ -242,7 +260,8 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16_to_pk_fp4(const bf16_t& x, float sca #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return float_to_e2m1(type_convert(x), scale); + auto res = float_to_mxfp4(type_convert(x), scale); + return pk_fp4_t::_pack(res, res); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float scale) @@ -250,7 +269,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float scale) @@ -258,7 +277,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return 
pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float scale) @@ -266,7 +285,7 @@ CK_TILE_HOST_DEVICE constexpr pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t& x, float #if CK_TILE_FP4_CVT_DEVICE return impl::_to_f4(x, scale); #else - return pk_fp4_t::pack(float_to_e2m1(x[0], scale), float_to_e2m1(x[1], scale)); + return pk_fp4_t::_pack(float_to_mxfp4(x[0], scale), float_to_mxfp4(x[1], scale)); #endif } @@ -301,7 +320,7 @@ CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return convert_to_float(unpack(number<0>{}), scale); + return convert_to_float(_unpack(number<0>{}), scale); #endif } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const @@ -309,8 +328,8 @@ CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp32x2_t{convert_to_float(unpack(number<0>{}), scale), - convert_to_float(unpack(number<1>{}), scale)}; + return fp32x2_t{convert_to_float(_unpack(number<0>{}), scale), + convert_to_float(_unpack(number<1>{}), scale)}; #endif } @@ -319,7 +338,7 @@ CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp16_t{type_convert(convert_to_float(unpack(number<0>{}), scale))}; + return fp16_t{type_convert(convert_to_float(_unpack(number<0>{}), scale))}; #endif } CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const @@ -327,28 +346,29 @@ CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const #if CK_TILE_FP4_CVT_DEVICE return impl::_from_f4(data, scale); #else - return fp16x2_t{type_convert(convert_to_float(unpack(number<0>{}), scale)), - type_convert(convert_to_float(unpack(number<1>{}), scale))}; + return 
fp16x2_t{type_convert(convert_to_float(_unpack(number<0>{}), scale)), + type_convert(convert_to_float(_unpack(number<1>{}), scale))}; #endif } #else CK_TILE_HOST_DEVICE constexpr float pk_fp4_t::to_float(float scale) const { - return e2m1_to_fp32_table[unpack(number<0>{})] * scale; + return e2m1_to_fp32_table[_unpack(number<0>{})] * scale; } CK_TILE_HOST_DEVICE constexpr fp32x2_t pk_fp4_t::to_fp32x2(float scale) const { - return fp32x2_t{e2m1_to_fp32_table[unpack(number<0>{})] * scale, e2m1_to_fp32_table[unpack(number<1>{}] * scale}; + return fp32x2_t{e2m1_to_fp32_table[_unpack(number<0>{})] * scale, e2m1_to_fp32_table[_unpack(number<1>{}] * scale}; } CK_TILE_HOST_DEVICE constexpr fp16_t pk_fp4_t::to_fp16(float scale) const { - return type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale; + return type_convert(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale; } CK_TILE_HOST_DEVICE constexpr fp16x2_t pk_fp4_t::to_fp16x2(float scale) const { return fp16x2_t{ - type_convert(type_convert(e2m1_to_fp16_table[unpack(number<0>{})]) * scale), - type_convert(type_convert(e2m1_to_fp16_table[unpack(number<1>{})]) * scale)}; + type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<0>{})]) * scale), + type_convert(type_convert(e2m1_to_fp16_table[_unpack(number<1>{})]) * + scale)}; } #endif diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 8b7541bf23..c7c4702e22 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -26,6 +26,29 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, return tile_window.load(number{}, bool_constant{}); } +/** + * @brief Load tile with elementwise function + * + * @note This function is a modification of the existing load function. + * It has been extended with two additional parameters: it takes a tuple as input + * and an elementwise function. 
For each A = A0, A1… AN, the elementwise function + * is additionally applied during a single read. + */ +template +CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) +{ + // TODO: Tile windows should works with unknow number of params + // Load element_wise API works only when the input typle is a tuple-tyupe + return tile_window[number<0>{}].load( + tile_window, elementwise, number{}, bool_constant{}); +} + template + CK_TILE_DEVICE auto load(const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) const + { + constexpr auto tile_dstr = typename Base::TileDstr{}; + auto dst_tensor = make_static_distributed_tensor(tile_dstr); + load(dst_tensor, + tile_window, + elementwise, + number{}, + bool_constant{}); + return dst_tensor; + } + + template + CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor, + const TileWindow_& tile_window, + ElementWise_ elementwise, + number = {}, + bool_constant = {}) const + { + + using Traits = typename Base::Traits; + using vector_t = typename Traits::vector_t; + using SFC_Ys = typename Traits::SFC_Ys; + + constexpr auto tile_dstr = typename Base::TileDstr{}; + constexpr auto sizeOfTuple = TileWindow_::size(); + // loop over thread tensor space [y0, y1, ...] + static_for<0, NumCoord, 1>{}([&](auto iCoord) { + /// TODO: use structure binding (to be captured later) if compiled in C++20 + auto window_adaptor_thread_coord = + tile_window[number<0>{}].pre_computed_coords_[iCoord][I0]; + auto bottom_tensor_thread_coord = + tile_window[number<0>{}].pre_computed_coords_[iCoord][I1]; + + static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { + constexpr auto iAccess = number{}; + + // data index [y0, y1, ...] 
+ constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess); + + // read from bottom tensor + const auto idx_vec_value = generate_tuple( + [&](auto jj) { + return tile_window[number{}] + .get_bottom_tensor_view() + .template get_vectorized_elements( + bottom_tensor_thread_coord, + 0, + bool_constant{}); + }, + number{}); + + // write into distributed tensor + static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { + constexpr auto idx_ys = generate_tuple( + [&](auto jj) { + return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) + : idx_ys_start[jj]; + }, + number{}); + + constexpr index_t d = + tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / + Traits::PackedSize; + + ck_tile::apply( + [&](auto&&... t) { + elementwise(dst_tensor.get_thread_buffer().template at(), + t.template get_as< + typename Base::DataType>()[j / Traits::PackedSize]...); + }, + idx_vec_value); + }); + // move thread coordinate + if constexpr(iCoordAccess != (NumAccessPerCoord - 1)) + { + constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess); + + constexpr auto idx_diff_ps_ys = container_concat( + generate_tuple([&](auto) { return number<0>{}; }, number{}), + idx_diff_ys); + + Base::move_window_adaptor_and_bottom_tensor_thread_coordinate( + window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys); + } + }); + }); + } + template @@ -857,6 +967,39 @@ CK_TILE_DEVICE void move_tile_window( window.move(step); } +template +CK_TILE_DEVICE void move_tile_window( + tuple>& window, + const typename tile_window_with_static_distribution::BottomTensorIndex& step) +{ + using T = tuple>; + + static constexpr auto N = T::size(); + static_for<0, N, 1>{}([&](auto Is) { window[number{}].move(step); }); +} + +template ::value>* = nullptr> +CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step) +{ + static constexpr auto N = TileWindowWithStaticDistributionType::size(); + static_for<0, N, 1>{}([&](auto Is) { 
window[number{}].move(step); }); +} + /** * @brief This class provides description of tile windowed view on the device memory. * @@ -887,6 +1030,58 @@ struct tile_window_with_static_lengths this->window_lengths_ = window_lengths; this->bottom_tensor_view_ = bottom_tensor_view; } + + /** + * @brief Print tile window elements for debugging. + * + * @tparam DataType Element data type (e.g., fp16_t, float, bf8_t) + * @param start_i Starting row (inclusive) + * @param end_i Ending row (exclusive) + * @param start_j Starting column (inclusive) + * @param end_j Ending column (exclusive) + * @param label Optional output label + * + * @note Tested on fp16. Custom types may need adjustments. + * @example tile_window.template print_tile_window_range(0, 4, 0, 8, "A"); + */ + template + CK_TILE_DEVICE void print_tile_window_range(index_t start_i, + index_t end_i, + index_t start_j, + index_t end_j, + const char* label = "") const + { + const auto& tensor_view = this->get_bottom_tensor_view(); + const auto window_origin = this->get_window_origin(); + + printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n", + label, + start_i, + end_i - 1, + start_j, + end_j - 1, + window_origin[0], + window_origin[1]); + + for(index_t i = start_i; i < end_i; i++) + { + for(index_t j = start_j; j < end_j; j++) + { + // Create coordinate for this element relative to window origin + auto coord = + make_tensor_coordinate(tensor_view.get_tensor_descriptor(), + make_tuple(window_origin[0] + i, window_origin[1] + j)); + + // Get the element using thread buffer type directly + using ThreadBuf = thread_buffer; + auto buf = tensor_view.template get_vectorized_elements(coord, 0); + auto value = buf.at(number<0>{}); // Extract first element from thread buffer + printf(" %s[%d,%d] = %f", label, i, j, static_cast(value)); + } + printf("\n"); + } + printf("\n"); + } }; template diff --git a/include/ck_tile/core/utility/random.hpp b/include/ck_tile/core/utility/random.hpp index f7fbfad4dd..6a38ad3bde 
100644 --- a/include/ck_tile/core/utility/random.hpp +++ b/include/ck_tile/core/utility/random.hpp @@ -24,7 +24,7 @@ struct prand_generator_t { CK_TILE_HOST_DEVICE uint32_t operator()(int id, float val, uint32_t seed = seed_) { - uint32_t x = *(reinterpret_cast(&val)); + uint32_t x = bit_cast(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits ^= x >> 16; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); @@ -43,7 +43,7 @@ struct prand_generator_t { CK_TILE_HOST_DEVICE uint32_t operator()(int id, half_t val, uint32_t seed = seed_) { - uint16_t x = *(reinterpret_cast(&val)); + uint16_t x = bit_cast(val); uint32_t drop_bits = uint32_t(x) & 0xFFFFu; drop_bits = ((drop_bits & 31) << 11) | (drop_bits >> 5); drop_bits *= 0x7000149; diff --git a/include/ck_tile/host/fill.hpp b/include/ck_tile/host/fill.hpp index e03881a1c7..817a46a8ea 100644 --- a/include/ck_tile/host/fill.hpp +++ b/include/ck_tile/host/fill.hpp @@ -67,7 +67,10 @@ struct FillUniformDistribution : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() { - return ck_tile::type_convert(dis(gen)); + if constexpr(numeric_traits::PackedSize == 2) + return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); + else + return ck_tile::type_convert(dis(gen)); }); }; threads[it] = joinable_thread(thread_f); @@ -77,8 +80,12 @@ struct FillUniformDistribution { std::mt19937 gen(seed_.has_value() ? 
*seed_ : std::random_device{}()); std::uniform_real_distribution dis(a_, b_); - std::generate( - first, last, [&dis, &gen]() { return ck_tile::type_convert(dis(gen)); }); + std::generate(first, last, [&dis, &gen]() { + if constexpr(numeric_traits::PackedSize == 2) + return ck_tile::type_convert(fp32x2_t{dis(gen), dis(gen)}); + else + return ck_tile::type_convert(dis(gen)); + }); } } diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index caa00e5994..d9379b4420 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -261,6 +261,81 @@ CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } +template >, + typename BDataType = remove_cvref_t>, + typename DDataType = remove_cvref_t>> +CK_TILE_HOST void +reference_gemm_multiple_abd(const std::array, AsDataType::size()>& as_m_k, + const std::array, BsDataType::size()>& bs_k_n, + const std::array, DsDataType::size()>& ds_m_n, + HostTensor& a_m_k, + HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const CDElementOp& acc_element_op = {}) +{ + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto as_m_k_tuple = + generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number{}); + + auto bs_k_n_tuple = + generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number{}); + + auto ds_m_n_tuple = + generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number{}); + + // Apply elementwise function to A + auto a_elementwise_fn = [&](auto i, auto j) { + ck_tile::apply([&](auto&&... 
t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple); + }; + + make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency()); + + // Apply elementwise function to B + auto b_elementwise_fn = [&](auto i, auto j) { + ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple); + }; + + make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency()); + + auto f_mk_kn_mn = [&](auto m, auto n) { + AccDataType v_acc = 0; + for(std::size_t k = 0; k < K; ++k) + { + ADataType v_a = a_m_k(m, k); + BDataType v_b = b_k_n(k, n); + v_acc += + ck_tile::type_convert(v_a) * ck_tile::type_convert(v_b); + } + + CDataType v_c = 0; + + ck_tile::apply( + [&](auto&&... t) { + acc_element_op(v_c, + ck_tile::type_convert(v_acc), + ck_tile::type_convert(t(m, n))...); + }, + ds_m_n_tuple); + + c_m_n(m, n) = ck_tile::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency()); +} + template () + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index 333762e5d7..9c5d99efc3 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -45,47 +45,57 @@ struct Generic2dBlockShape static constexpr index_t Block_N = BlockTile_::at(number<1>{}); static constexpr index_t ThreadPerBlock_M = ThreadPerBlock_::at(number<0>{}); static constexpr index_t ThreadPerBlock_N = ThreadPerBlock_::at(number<1>{}); - static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; // vector size along seq static constexpr index_t Vector_M = Vector_::at(number<0>{}); static constexpr index_t Vector_N = Vector_::at(number<1>{}); - static constexpr bool is_warp_per_row = ThreadPerBlock_N <= get_warp_size(); - static_assert((ThreadPerBlock_M * 
ThreadPerBlock_N) % get_warp_size() == 0); - static constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / get_warp_size(); - // num warps along seq, within each block - static constexpr index_t WarpPerBlock_M = []() { + template + static constexpr index_t GetWarpPerBlock_M() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size; + static_assert((ThreadPerBlock_M * ThreadPerBlock_N) % warp_size == 0); + constexpr index_t total_warps = (ThreadPerBlock_M * ThreadPerBlock_N) / warp_size; + if constexpr(is_warp_per_row) { - static_assert(get_warp_size() % ThreadPerBlock_N == 0); - return total_warps * (get_warp_size() / ThreadPerBlock_N); + static_assert(warp_size % ThreadPerBlock_N == 0); + return total_warps * (warp_size / ThreadPerBlock_N); } else { // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); - return total_warps / (ThreadPerBlock_N / get_warp_size()); + return total_warps / (ThreadPerBlock_N / warp_size); } - }(); + }; // num of warps along n - static constexpr index_t WarpPerBlock_N = []() { + template + static constexpr index_t GetWarpPerBlock_N() + { + constexpr index_t warp_size = isHostWave32 ? 
32 : get_warp_size(); + constexpr bool is_warp_per_row = ThreadPerBlock_N <= warp_size; if constexpr(is_warp_per_row) { - static_assert(get_warp_size() % ThreadPerBlock_N == 0); + static_assert(warp_size % ThreadPerBlock_N == 0); return 1; } else { - static_assert(ThreadPerBlock_N % get_warp_size() == 0); - return ThreadPerBlock_N / get_warp_size(); + static_assert(ThreadPerBlock_N % warp_size == 0); + return ThreadPerBlock_N / warp_size; } - }(); + } + + static constexpr index_t WarpPerBlock_M = GetWarpPerBlock_M(); + static constexpr index_t WarpPerBlock_N = GetWarpPerBlock_N(); // warp size - static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; - static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; + static constexpr index_t BlockSize = WarpPerBlock_M * WarpPerBlock_N * get_warp_size(); + static constexpr index_t Warp_M = ThreadPerBlock_M / WarpPerBlock_M * Vector_M; + static constexpr index_t Warp_N = ThreadPerBlock_N / WarpPerBlock_N * Vector_N; static_assert(Warp_M % Vector_M == 0); static_assert(Warp_N % Vector_N == 0); static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); @@ -98,6 +108,13 @@ struct Generic2dBlockShape // num of threads along seq, within each warp static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + template + static constexpr index_t GetBlockSize() + { + constexpr index_t warp_size = isHostWave32 ? 32 : get_warp_size(); + return GetWarpPerBlock_M() * GetWarpPerBlock_N() * warp_size; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp new file mode 100644 index 0000000000..f8432b9da0 --- /dev/null +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/ops/elementwise.hpp" + +namespace ck_tile { + +template +struct is_pk_int4 : std::false_type +{ +}; +template <> +struct is_pk_int4 : std::true_type +{ +}; + +template +struct InterleavedPKTypeLoader +{ + template + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, + const WarpWindow& warp_window) + { + const element_wise::PassThroughPack8 elementwise_op{}; + + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); + + using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), + in_dstr_tensors.get_thread_buffer().template get_as()[i]); + }); + } +}; + +template +CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) +{ + if constexpr(is_pk_int4>::value) + { + InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + } + else + { + dst = load_tile(src); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 9e3ccb025d..221592ee10 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -162,6 +162,16 @@ CK_TILE_DEVICE bf16x4_t i4_to_bhalf4(int q) */ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) { +#if CK_TILE_USE_OCP_FP8 + // register values [3, 2, 1, 0] + static constexpr uint32_t reg0 = 0xcaccced0; + // register values [7, 6, 5, 4] + static constexpr uint32_t reg1 = 0xb8c0c4c8; + // register values [-1, -2, -3, -4] + static constexpr uint32_t reg2 = 0x44403800; + // register values [-5, -6, -7, -8] + static 
constexpr uint32_t reg3 = 0x4e4c4a48; +#else // register values [3, 2, 1, 0] static constexpr uint32_t reg0 = 0xd2d4d6d8; // register values [7, 6, 5, 4] @@ -170,6 +180,7 @@ CK_TILE_DEVICE fp8x8_t amd_assembly_i4_to_fp8x8(int a) static constexpr uint32_t reg2 = 0x4C484000; // register values [-5, -6, -7, -8] static constexpr uint32_t reg3 = 0x56545250; +#endif uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; @@ -227,6 +238,16 @@ CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src) */ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) { +#if CK_TILE_USE_OCP_FP8 + // register values [3, 2, 1, 0] + static constexpr uint32_t reg0 = 0Xc5c6c7c8; + // register values [7, 6, 5, 4] + static constexpr uint32_t reg1 = 0Xbcc0c2c4; + // register values [11, 10, 9, 8] + static constexpr uint32_t reg2 = 0X42403c00; + // register values [15, 14, 13, 12] + static constexpr uint32_t reg3 = 0X47464544; +#else // register values [3, 2, 1, 0] static constexpr uint32_t reg0 = 0Xc9cacbcc; // register values [7, 6, 5, 4] @@ -235,6 +256,7 @@ CK_TILE_DEVICE bf8x8_t amd_assembly_i4_to_bf8x8(uint32_t a) static constexpr uint32_t reg2 = 0X46444000; // register values [15, 14, 13, 12] static constexpr uint32_t reg3 = 0X4b4a4948; +#endif uint32_t tmp_pos, tmp_neg, tmp_res_even, tmp_res_odd, final_sel; @@ -370,6 +392,23 @@ struct PassThrough } }; +struct AddScale +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... 
as) const + { + // Start with the base value c + float result = ck_tile::type_convert(0.0f); + + // Add by each D parameter using fold expression + ((result += ck_tile::type_convert(as)), ...); + + a = ck_tile::type_convert(scale * result); + } + + float scale = 1.0; +}; + struct MultiDMultiply { template diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 47de9af3b5..2861a7c216 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -31,8 +31,8 @@ struct GetDataType using type = typename T::DataType; // Use T::ScaleN::DataType }; -template struct CShuffleEpilogueProblem { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; static constexpr index_t kBlockSize = MWave_ * NWave_ * (kNumWaveGroups_ > 1 ? 
KWave_ : 1) * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; static constexpr index_t kNPerBlock = kN_; static constexpr index_t MWave = MWave_; @@ -88,12 +89,27 @@ template struct CShuffleEpilogue { using Problem = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; + + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using ATypeToUse = std::conditional_t, BDataType, ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 54becd3c0f..2843966cd7 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -28,8 +28,8 @@ struct Default2DEpilogueProblem static constexpr index_t NumDTensor = 0; }; -template { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CLayout = remove_cvref_t; using DsDataType = remove_cvref_t; using CDElementwise = remove_cvref_t; @@ -157,14 +157,28 @@ struct Default2DEpilogue template struct DefaultGemm2DEpilogue : public Default2DEpilogue { - using Problem = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; + 
using Problem = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; + using DsDataType = remove_cvref_t; using DsLayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/block/block_masking.hpp b/include/ck_tile/ops/fmha/block/block_masking.hpp index f5c12e11d2..2c45945fac 100644 --- a/include/ck_tile/ops/fmha/block/block_masking.hpp +++ b/include/ck_tile/ops/fmha/block/block_masking.hpp @@ -203,27 +203,36 @@ struct GenericAttentionMask CK_TILE_HOST_DEVICE constexpr auto IsEdgeTile(index_t i_tile_top, index_t i_tile_left, number, number) const { - if constexpr(IsLocal) + if constexpr(!IsMasking) { - // check top-right corner > x or left-borrom corner < x - index_t i_tile_right = i_tile_left + TileWidth; - index_t i_tile_bottom = i_tile_top + TileHeight; - index_t x_end = min(i_tile_top + x, x_total); - - bool top_right_edge = i_tile_right > (i_tile_top + x); - bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); - bool is_partial_out_of_bound = i_tile_right > x_end; // only consider right-pad for now - - return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + // TODO: no need to check begin + return (i_tile_left + TileWidth) > x_total; } else { - // only need to check top-right corner > x - index_t i_tile_right = i_tile_left + TileWidth; - index_t x_end = 
min(i_tile_top + x, x_total); + if constexpr(IsLocal) + { + // check top-right corner > x or left-borrom corner < x + index_t i_tile_right = i_tile_left + TileWidth; + index_t i_tile_bottom = i_tile_top + TileHeight; + index_t x_end = min(i_tile_top + x, x_total); - bool top_right_edge = i_tile_right > x_end; - return top_right_edge; + bool top_right_edge = i_tile_right > (i_tile_top + x); + bool bottom_left_edge = i_tile_bottom > (i_tile_left + y); + bool is_partial_out_of_bound = + i_tile_right > x_end; // only consider right-pad for now + + return top_right_edge || bottom_left_edge || is_partial_out_of_bound; + } + else + { + // only need to check top-right corner > x + index_t i_tile_right = i_tile_left + TileWidth; + index_t x_end = min(i_tile_top + x, x_total); + + bool top_right_edge = i_tile_right > x_end; + return top_right_edge; + } } } diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 9d848dfd7a..58fdad149a 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -1446,29 +1446,35 @@ struct FmhaFwdKernel auto o_acc_tile = [&]() { if constexpr(kDoFp8StaticQuant) { - return FmhaPipeline{}( - q_dram_window, - identity{}, // q_element_func - k_dram_window, - identity{}, // k_element_func - v_dram_window, - identity{}, // v_element_func - bias_dram_window, - identity{}, // bias_element_func - randval_dram_window, - lse_dram_window, - identity{}, // lse_element_func - identity{}, // s_acc_element_func - scales{kargs.scale_p}, // p_compute_element_func - composes(saturates{}, scales{kargs.scale_o}), // o_acc_element_func - mask, - position_encoding, - kargs.scale_s, - variant, - variant_params, - block_indices, - smem_ptr, - dropout); + auto o_acc_element_func = [&]() { + if constexpr(std::is_same_v) + return ck_tile::composes(ck_tile::saturates{}, + ck_tile::scales{kargs.scale_o}); + else + return 
ck_tile::scales{kargs.scale_o}; + }(); + return FmhaPipeline{}(q_dram_window, + identity{}, // q_element_func + k_dram_window, + identity{}, // k_element_func + v_dram_window, + identity{}, // v_element_func + bias_dram_window, + identity{}, // bias_element_func + randval_dram_window, + lse_dram_window, + identity{}, // lse_element_func + identity{}, // s_acc_element_func + scales{kargs.scale_p}, // p_compute_element_func + o_acc_element_func, // o_acc_element_func + mask, + position_encoding, + kargs.scale_s, + variant, + variant_params, + block_indices, + smem_ptr, + dropout); } else { @@ -1868,7 +1874,7 @@ struct FmhaFwdKernel const auto v_dram_naive = make_naive_tensor_view( data, // will update this pointer if using paged-kvcache make_tuple(length, kargs.hdim_v), - make_tuple(kargs.hdim_v, 1), + make_tuple(kargs.stride_v, 1), number{}, number<1>{}); diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp index 87021354aa..c5e5745817 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_v3_kernel.hpp @@ -81,6 +81,7 @@ struct FmhaFwdV3Kernel // ck_tile::index_t window_size_left, window_size_right; ck_tile::index_t window_size_left, window_size_right; ck_tile::GenericAttentionMaskEnum mask_type; + ck_tile::index_t remap_opt; }; struct FmhaFwdCommonLSEKargs @@ -143,7 +144,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + ck_tile::index_t remap_opt) { Kargs kargs{{q_ptr, k_ptr, @@ -176,6 +178,7 @@ struct FmhaFwdV3Kernel kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); + kargs.remap_opt = remap_opt; } if constexpr(kStoreLSE) { @@ -213,7 +216,8 @@ struct FmhaFwdV3Kernel ck_tile::index_t nhead_stride_o, ck_tile::index_t 
window_size_left, ck_tile::index_t window_size_right, - ck_tile::index_t mask_type) + ck_tile::index_t mask_type, + ck_tile::index_t remap_opt) { Kargs kargs{{q_ptr, k_ptr, @@ -245,6 +249,7 @@ struct FmhaFwdV3Kernel kargs.window_size_left = window_size_left; kargs.window_size_right = window_size_right; kargs.mask_type = static_cast(mask_type); + kargs.remap_opt = remap_opt; } if constexpr(kStoreLSE) { @@ -261,39 +266,81 @@ struct FmhaFwdV3Kernel ck_tile::index_t hdim_v_) { // TODO: this may need tuning - return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * - ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), - nhead_, - batch_size_); - } - - CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) - { - using namespace ck_tile; - - // const index_t num_tile_m0 = seqlen_q / kM0; - const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1); - - const index_t i_block = blockIdx.x; - const index_t i_nhead = blockIdx.y; - const index_t i_batch = blockIdx.z; - - const auto f = [](index_t dividend, index_t divisor) { - index_t quotient = dividend / divisor; - index_t modulus = dividend - quotient * divisor; - return ck_tile::make_tuple(quotient, modulus); - }; - - const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); - if constexpr(kHasMask) { - // assume that num_tile_n1 is always 1 - return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch); + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + batch_size_); } else { - return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + return dim3(nhead_, + ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1), + batch_size_); + } + } + + CK_TILE_DEVICE static constexpr auto + RemapTileIndices(int32_t tg_idx, int32_t tg_idy, int32_t remap_option) + { + 
if(remap_option < 1) + { + return make_tuple(static_cast(gridDim.x - tg_idx - 1), tg_idy); + } + + int32_t remapped_tg_idx = tg_idx; + int32_t remapped_tg_idy = tg_idy; + + if(remap_option == 2) + { // special remapping + int32_t tmp0 = (remapped_tg_idy & 0x7) * gridDim.x + remapped_tg_idx; + int32_t tmp1 = tmp0 & 0x7; + + remapped_tg_idx = tmp0 >> 3; + remapped_tg_idy = (remapped_tg_idy & 0xfffffff8) + tmp1; + } + else + { // normal remapping + int32_t cus_per_xdim_per_xcc = gridDim.x >> 3; + int32_t tgs_cu_id = remapped_tg_idx >> 3; + + if(tgs_cu_id < cus_per_xdim_per_xcc) + { + int32_t tgs_xcc_id = remapped_tg_idx & 0x7; + int32_t new_tg_idx = tgs_xcc_id * cus_per_xdim_per_xcc + tgs_cu_id; + + remapped_tg_idx = new_tg_idx; + } + } + + return make_tuple(remapped_tg_idx, remapped_tg_idy); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs&) + { + using namespace ck_tile; + + // const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, + // FmhaPipeline::kN1); + + // assume that num_tile_n1 is always 1 + if constexpr(kHasMask) + { + const index_t i_nhead = blockIdx.x; + const index_t i_block = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(gridDim.y - 1 - i_block, 0, i_nhead, i_batch); + } + else + { + const index_t i_nhead = blockIdx.x; + const index_t i_block = blockIdx.y; + const index_t i_batch = blockIdx.z; + + return ck_tile::make_tuple(i_block, 0, i_nhead, i_batch); } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index b883aad155..c402eaeac4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -559,6 +559,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( 
Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds on the previous + // iteration to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -814,6 +817,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto shuffled_bias_tile = make_static_distributed_tensor( Policy::template MakeShuffledBiasTileDistribution()); shuffle_tile(shuffled_bias_tile, bias_tile); + // SGrad and Bias use the same address in LDS, finish loading ds in the hot loop to + // reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, shuffled_bias_tile); block_sync_lds(); auto bias_s_tile = load_tile(bias_s_lds_read_window); @@ -956,6 +962,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP return cast_tile(ds); } }(); + // Finish loading bias_s to reuse LDS. + block_sync_lds(); store_tile(bias_lds_write_window, dbias); block_sync_lds(); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); @@ -975,11 +983,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); - if constexpr(kHasBiasGrad) - { - // SGrad and BiasGrad use the same address in LDS. - block_sync_lds(); - } + // SGrad and Bias/BiasGrad use the same address in LDS, finish loading bias/dbias or, when + // bias is not used, loading ds in the hot loop to reuse LDS. 
+ block_sync_lds(); store_tile(ds_lds_window, ds_gemm); block_sync_lds(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 81950bd30a..41cb4fc306 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -698,6 +698,12 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } s_waitcnt(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 16d9f695df..6d90429407 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -489,7 +489,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR move_tile_window(k_dram_window, {kN0, 0}); async_load_tile(v_lds_write_window, v_dram_window); move_tile_window(v_dram_window, {kN0, 0}); - // __builtin_amdgcn_s_waitcnt(0); + s_waitcnt(); k_reg_tensor = load_tile(k_lds_read_window); v_reg_tensor = load_tile(v_lds_read_window); kt_reg_tensor = load_tile_transpose(kt_lds_read_window); @@ -636,7 +636,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR } }(); store_tile(bias_lds_write_window, dbias); - __builtin_amdgcn_s_waitcnt(3952); + s_waitcnt(); block_sync_lds(); auto shuffled_dbias_tile = load_tile(dbias_lds_read_window); auto dbias_tile = make_static_distributed_tensor( @@ 
-656,9 +656,15 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); dk_acc = gemm_3(dst_reg_tensor, qt_reg_tensor); + if constexpr(kHasBiasGrad) + { + // SGrad and BiasGrad use the same address in LDS, finish loading dbias to reuse + // LDS. + block_sync_lds(); + } store_tile(ds_lds_window, ds_gemm); } - __builtin_amdgcn_s_waitcnt(3952); + s_waitcnt(); block_sync_lds(); if constexpr(is_epilogue) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 68ead7c765..ad9e2959f5 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -1941,7 +1941,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt; constexpr index_t smem_size_stage0_1 = smem_size_v; - constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + +smem_size_dot + + constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot + smem_size_do + smem_size_lse + smem_size_d + max(smem_size_bias, smem_size_ds); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index 20d84116d4..5e2a4e898b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -57,7 +57,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 1) {} + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 2) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -68,11 +72,19 @@ struct 
CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU }); } - else if constexpr(Phase == 3) {} + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } } else { - if constexpr(Phase == 0) {} + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 1) { static_for<0, 8, 1>{}([&](auto) { @@ -81,7 +93,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 2) {} + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 3) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -115,7 +131,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 1) {} + else if constexpr(Phase == 1) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 2) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -126,11 +146,19 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU }); } - else if constexpr(Phase == 3) {} + else if constexpr(Phase == 3) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } } else { - if constexpr(Phase == 0) {} + if constexpr(Phase == 0) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 1) { static_for<0, 8, 1>{}([&](auto) { @@ -139,7 +167,11 @@ struct CoreLoopScheduler __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU }); } - else if constexpr(Phase == 2) 
{} + else if constexpr(Phase == 2) + { + __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU + __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU + } else if constexpr(Phase == 3) { #if !CK_TILE_DISABLE_PACKED_FP32 @@ -177,6 +209,15 @@ CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs) return result; } +CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs) +{ + float result; + asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]" + : [result] "=v"(result) + : [lhs] "v"(lhs), [rhs] "v"(rhs)); + return result; +} + CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b) { fp16x2_t result; @@ -466,7 +507,7 @@ struct BlockFmhaFwdV3Pipeline statically_indexed_array sp; decltype(gemm_1.MakeCBlockTile()) o_acc; - constexpr index_t fmha_alu_D_reg_cnt = 0; // threshold to decide how many fmha_alu_D_upd() + constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd() // instructions should we move to fmha_alu1() static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size()); @@ -631,8 +672,8 @@ struct BlockFmhaFwdV3Pipeline // K_mem_su_ld_insts = 1 for 32 x 128 // V_mem_su_ld_insts = 1 for 128 x 32 - static constexpr int K_mem_su_ld_insts = 1; - static constexpr int V_mem_su_ld_insts = 1; + constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access(); + constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access(); auto K_mem_load = [&](auto k_lds_write_idx) { async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window); @@ -648,7 +689,6 @@ struct BlockFmhaFwdV3Pipeline auto V_mem_load = [&](auto v_lds_write_idx) { async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window); - __builtin_amdgcn_sched_barrier(0); /// FIXME: use the future-predicting method to move the window move_tile_window(v_dram_window, {kK1, 0}); @@ -726,11 +766,12 @@ struct BlockFmhaFwdV3Pipeline #else block_tile_reduce_sync(rowsum_p, f_sum, bool_constant{}); #endif - // update partial o_acc [0, 2) - static_for<0, 
ck_tile::min(2, fmha_alu_D_reg_cnt), 1>{}( - [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); // l{j} + /// Note: The compiler keeps moving the following instructions elsewhere because 'l' + /// is first consumed later. To anchor them here, we rewrite the final addition in + /// inline assembly to create a dependency, forcing the dependent instructions to + /// be emitted at this point. constexpr auto o_spans = decltype(o_acc)::get_distributed_spans(); sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) { constexpr auto i_idx = make_tuple(idx0); @@ -739,13 +780,15 @@ struct BlockFmhaFwdV3Pipeline l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]); }); - // update partial o_acc [2, fmha_alu_D_reg_cnt) - static_for<2, ck_tile::max(2, fmha_alu_D_reg_cnt), 1>{}( - [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; }); + // update partial o_acc [0, fmha_alu_D_reg_cnt) + static_for<0, fmha_alu_D_reg_cnt, 1>{}([&](auto idx) { + o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale); + }); - /// NOTICE: Compiler keep moving the conversion instructions to other places. We rewite - /// the cast_tile() call into inline asm to force the conversion instructions to be - /// generated here. The fmha_alu1() call should be placed at the end of a phase. + /// Note: The compiler keeps sinking the conversion instructions because the + /// result 'p' is only consumed later. To anchor them here, we rewrite + /// the cast_tile() call as inline assembly, forcing the conversions to be + /// emitted at this point. static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0); static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) { float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]); @@ -763,6 +806,10 @@ struct BlockFmhaFwdV3Pipeline sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y; } }); + + /// Note: Place fmha_alu1() at the end of the phase. 
The surrounding inline assembly + /// can interfere with the behavior of sched_group_barrier(), so ending the phase here + /// avoids unintended reordering. }; auto gemm = [&](auto sp_reg_idx, auto gemm_idx) { @@ -937,9 +984,9 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx); + Scheduler::schedule(cl_p, number<1>{}); fmha_mask(xdl_SP_p01_reg_idx); - Scheduler::schedule(cl_p, number<1>{}); __builtin_amdgcn_sched_barrier(0); // phase2 ASM_MARKER("phase2 Wave0-3"); @@ -947,6 +994,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 0"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p23_reg_idx, gemm1); Scheduler::schedule(cl_p, number<2>{}); @@ -995,6 +1044,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p01_reg_idx, gemm0); fmha_alu1(xdl_SP_p23_reg_idx); @@ -1005,9 +1056,9 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx); + Scheduler::schedule(cl_p, number<2>{}); fmha_mask(xdl_SP_p01_reg_idx); - Scheduler::schedule(cl_p, number<2>{}); kv_token_start += kN0; if(num_total_loop <= ++i_total_loops) { @@ -1021,6 +1072,8 @@ struct BlockFmhaFwdV3Pipeline __builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_s_barrier(); __builtin_amdgcn_sched_barrier(0); + asm volatile("s_nop 1"); + __builtin_amdgcn_sched_barrier(0); cl_calc(xdl_SP_p23_reg_idx, gemm1); Scheduler::schedule(cl_p, number<3>{}); @@ -1036,7 +1089,14 @@ struct BlockFmhaFwdV3Pipeline auto ps_pi = number<1>{} - d; auto V_lds_rd_idx = ps_pi; - s_waitcnt_vmcnt(); + if(1 < num_total_loop) + { + s_waitcnt_vmcnt(); + } + else + { + s_waitcnt_vmcnt<0>(); + } 
__builtin_amdgcn_s_barrier(); V_lds_load(V_lds_rd_idx); @@ -1102,14 +1162,14 @@ struct BlockFmhaFwdV3Pipeline V_mem_load(number<1>{}); // V1 K_lds_load(number<1>{}); // K1 - asm volatile("s_setprio 0"); + __builtin_amdgcn_s_setprio(0); __builtin_amdgcn_s_barrier(); while(core_loop(number<0>{})) ; } if(warp_group_id != 0) { - asm volatile("s_setprio 1"); + __builtin_amdgcn_s_setprio(1); __builtin_amdgcn_s_barrier(); while(core_loop(number<1>{})) ; @@ -1167,14 +1227,13 @@ struct BlockFmhaFwdV3Pipeline typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename LSEDramBlockWindowTmp> - CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile - FmhaMask mask, - float scale_s, - void* smem_ptr) const + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile + FmhaMask mask, + float scale_s, + void* smem_ptr) const { using namespace ck_tile; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index de13e305e0..6e07dbc00e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -31,6 +31,7 @@ #include "ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp" diff --git 
a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index e1b0792ecf..94adb42880 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -13,7 +14,9 @@ namespace ck_tile { // A is block window on shared memory // B is block window on shared memory // C is block distributed tensor -template +template struct BlockUniversalGemmAsBsCr { private: @@ -91,6 +94,7 @@ struct BlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -179,25 +183,6 @@ struct BlockUniversalGemmAsBsCr return b_block_dstr_encode; } - private: - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - constexpr index_t UnaryOpSize = 8; - const element_wise::PassThroughPack8 elementwise_op{}; - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } - template struct BlockGemmImpl { @@ -239,7 +224,7 @@ struct BlockUniversalGemmAsBsCr if 
constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -247,7 +232,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { @@ -317,7 +302,7 @@ struct BlockUniversalGemmAsBsCr { if constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else if constexpr(ALoadTranspose) { @@ -329,7 +314,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else if constexpr(BLoadTranspose) { @@ -468,7 +453,7 @@ struct BlockUniversalGemmAsBsCr if constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else if constexpr(ALoadTranspose) { @@ -480,7 +465,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else if constexpr(BLoadTranspose) { diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index fcfbf9635f..588d903b25 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -90,10 +90,10 @@ struct BatchedGemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. 
static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); struct BatchedGemmKernelArgs : ck_tile::UniversalGemmKernelArgs<> { diff --git a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp index e37b4f36d4..d632b1596c 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp @@ -89,7 +89,7 @@ struct GemmKernel /// @brief Specify the layout configurations for A, B, E and D using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; /// @brief Specify the data type configurations for A, B, E and D using ADataType = remove_cvref_t; @@ -106,10 +106,10 @@ struct GemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); static constexpr index_t NumATensor = 1; static constexpr index_t NumBTensor = 1; diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp new file mode 100644 index 0000000000..3b050e03ed --- /dev/null +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp @@ -0,0 +1,193 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/host/concat.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/stream_utils.hpp" +#include "ck_tile/core/utility/env.hpp" +#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp" +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +/// @brief The MultiABD GEMM kernel host arguments. +/// +/// @par Overview +/// This structure is passed to @ref GemmKernelMultiABD "GemmKernelMultiABD" when creating +/// kernel arguments object. It contain all necessary information required to build proper +/// kernel argument and launch kernel on GPU. This structure defines the GEMM problem +/// configuration by stating all required information like M,N,K sizes and respective strides. +/// NumATensor describes the number of A tensors. The minimum number of tensors is 1(required). +/// NumBTensor describes the number of B tensors. The minimum number of tensors is 1(required). +/// NumDTensor describes the number of D tensors. The minimum number of tensors is 0(not +/// required). 
+template +struct GemmMultiABDHostArgs +{ + CK_TILE_HOST GemmMultiABDHostArgs(const std::array& as_ptr_, + const std::array& bs_ptr_, + const std::array& ds_ptr_, + void* e_ptr_, + index_t k_batch_, + index_t M_, + index_t N_, + index_t K_, + const std::array& stride_As_, + const std::array& stride_Bs_, + const std::array& stride_Ds_, + index_t stride_E_) + : as_ptr(as_ptr_), + bs_ptr(bs_ptr_), + ds_ptr(ds_ptr_), + e_ptr(e_ptr_), + M(M_), + N(N_), + K(K_), + stride_As(stride_As_), + stride_Bs(stride_Bs_), + stride_Ds(stride_Ds_), + stride_E(stride_E_), + k_batch(k_batch_) + { + } + + const std::array as_ptr; + const std::array bs_ptr; + const std::array ds_ptr; + union + { + void* e_ptr; + void* c_ptr; + }; + index_t M; + index_t N; + index_t K; + const std::array stride_As; + const std::array stride_Bs; + const std::array stride_Ds; + union + { + index_t stride_E; + index_t stride_C; + }; + + index_t k_batch; +}; + +template +struct GemmKernelMultiABD +{ + /// @brief Inject the UniversalGemmKernel base class to support execution of all necessary + /// functions. + using UniversalGemmKernel = + UniversalGemmKernel; + static constexpr index_t kBlockSize = UniversalGemmKernel::kBlockSize; + + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + + /// @brief Specify the layout configurations for A, B, E and D + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + using DsLayout = remove_cvref_t; + + /// @brief Specify the data type configurations for A, B, E and D + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using EDataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + + /// @brief ALayout and ADataType are expected to be a tuple, not a scalar. 
+ static_assert(is_detected::value && + is_detected::value, + "ALayout and ADataType must be a tuple."); + + /// @brief BLayout and BDataType are expected to be a tuple, not a scalar. + static_assert(is_detected::value && + is_detected::value, + "BLayout and BDataType must be a tuple."); + + /// @brief CLayout and EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && + !is_detected::value, + "CLayout and EDataType must be a scalar."); + + /// @brief DsLayout and DsDataType are expected to be tuples, not scalars. + static_assert(is_detected::value && + is_detected::value && + DsLayout::size() == DsDataType::size() && DsLayout::size() > 0, + "DsLayout and DsDataType must be tuples and must have the same size."); + + /// @brief NumATensor, NumBTensor and NumDTensor are taken from the sizes of the data type tuples. + static constexpr index_t NumATensor = AsDataType::size(); + static constexpr index_t NumBTensor = BsDataType::size(); + static constexpr index_t NumDTensor = DsDataType::size(); + + CK_TILE_HOST static auto GetName() -> const std::string + { + return UniversalGemmKernel::GetName(); + } + + CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) -> dim3 + { + return UniversalGemmKernel::GridSize(M, N, KBatch); + } + + CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3 + { + return UniversalGemmKernel::MaxOccupancyGridSize(s); + } + + CK_TILE_HOST static constexpr auto BlockSize() -> dim3 + { + return UniversalGemmKernel::BlockSize(); + } + + CK_TILE_HOST static constexpr auto + MakeKernelArgs(const GemmMultiABDHostArgs& hostArgs) -> + typename UniversalGemmKernel::KernelArgs + { + /// @brief Universal GEMM requires array objects and corresponding stride information for + /// matrices A, B, and D. 
+ return UniversalGemmKernel::MakeKernelArgs( + UniversalGemmHostArgs(hostArgs.as_ptr, + hostArgs.bs_ptr, + hostArgs.ds_ptr, + hostArgs.e_ptr, + hostArgs.k_batch, + hostArgs.M, + hostArgs.N, + hostArgs.K, + hostArgs.stride_As, + hostArgs.stride_Bs, + hostArgs.stride_Ds, + hostArgs.stride_E)); + } + + CK_TILE_HOST static auto + IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) -> bool + { + // Currently MultiABD kernel doesn't support k_batch > 1 + if(kargs.k_batch > 1) + { + return false; + } + + return UniversalGemmKernel::IsSupportedArgument(kargs); + } + + CK_TILE_DEVICE auto operator()(typename UniversalGemmKernel::KernelArgs kargs) const -> void + { + UniversalGemmKernel{}.template operator()(kargs); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp index 9d3ac8b901..b0b2905cb4 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_multi_d_kernel.hpp @@ -95,7 +95,7 @@ struct GemmKernelMultiD /// @brief Specify the layout configurations for A, B, E and D using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; using DsLayout = remove_cvref_t; /// @brief Specify the data type configurations for A, B, E and D @@ -114,10 +114,10 @@ struct GemmKernelMultiD !is_detected::value, "BLayout and BDataType must be scalars."); - /// @brief ELayout and EDataType are expected to be scalars, not a tuple. - static_assert(!is_detected::value && + /// @brief CLayout and EDataType are expected to be scalars, not a tuple. + static_assert(!is_detected::value && !is_detected::value, - "ELayout and EDataType must be scalars."); + "CLayout and EDataType must be scalars."); /// @brief DsLayout and DsDataType are expected to be tuple, not a scalar. 
static_assert(is_detected::value && diff --git a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp index 92ae6411a5..a891d4df55 100644 --- a/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp @@ -646,16 +646,13 @@ struct StreamKTilePartitioner * @brief Get length of loop iterations for stream-k loop */ CK_TILE_DEVICE uint32_t GetCurrentIterLength(uint32_t iter_start, - uint32_t iter_end, - uint32_t total_iter_length) const noexcept + uint32_t iter_end) const noexcept { - uint32_t iter_length_mod, iter_length_quo /*unused*/; - k_iters_per_tile.divmod(iter_end, iter_length_quo, iter_length_mod); - uint32_t total_iter_length_val = static_cast(total_iter_length); - uint32_t current_iter_length = - min(iter_length_mod == 0 ? (iter_end - iter_start) : iter_length_mod, - total_iter_length_val); - return current_iter_length; + // A WG's iter_end is either in the current C macro tile or not. + // If it is not, then the macro tile boundary is where the WG must stop. 
+ uint32_t distance_to_tile_boundary = + k_iters_per_tile.get() - (iter_start % k_iters_per_tile.get()); + return min(iter_start + distance_to_tile_boundary, iter_end) - iter_start; } /** @@ -672,9 +669,7 @@ struct StreamKTilePartitioner CK_TILE_DEVICE void GetTileIdxWithOffset(uint32_t iter, uint32_t& tile_idx, uint32_t& iter_offset) const noexcept { - uint32_t tile_idx_val = static_cast(tile_idx); - uint32_t iter_offset_val = static_cast(iter_offset); - k_iters_per_tile.divmod(iter, tile_idx_val, iter_offset_val); + k_iters_per_tile.divmod(iter, tile_idx, iter_offset); } /** diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 704d0d01ee..df1d6c9e4f 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -120,10 +120,10 @@ struct GroupedGemmKernel !is_detected::value && !is_detected::value, "BLayout and BDataType must be scalars. Multiple parameters are not currently supported."); - /// @brief C/ELayout and C/EDataType are expected to be scalars, not a tuple. + /// @brief C/CLayout and C/EDataType are expected to be scalars, not a tuple. 
static_assert(!is_detected::value && !is_detected::value, - "C/ELayout and C/EDataType must be scalars."); + "C/CLayout and C/EDataType must be scalars."); using OffsetTile1DPartitioner = OffsettedTile1DPartitioner; using Kernel = GroupedGemmKernel; @@ -292,34 +292,8 @@ struct GroupedGemmKernel { __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle) - { - - RunGemmWithPipelineSelection2LDS(a_ptr, - b_ptr, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - return; - } - else - { - - Base::RunGemm2LDS({a_ptr}, - {b_ptr}, - {/*ds_ptr*/}, - c_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + RunGemmWithPipelineSelection2LDS( + a_ptr, b_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n); } else // SingleSmemBuffer { @@ -374,7 +348,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = @@ -390,12 +364,8 @@ struct GroupedGemmKernel const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); + const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(Base::I3); EpiloguePipeline{}.template @@ -436,7 +406,7 @@ struct GroupedGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& 
gemm_tensor_views_tuple = Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset); + {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp index 77c431e49c..5df1f092d7 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm_kernel.hpp @@ -141,11 +141,17 @@ struct StreamKKernel return UniversalGemmKernel::BlockSize(); } - CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args) + /// @brief Constructs kernel arguments for the Stream-K kernel. + /// @param host_args Stream-K host arguments. + /// @param num_cu Number of compute units (CUs). The default is the number of CUs on the device. + /// The caller may select their own to assist with test reproducibility, etc. + /// @param occupancy The maximum number of active blocks per CU for this kernel. The caller may + /// select their own to assist with test reproducibility, etc. + /// @return The kernel arguments for Stream-K. 
+ CK_TILE_HOST static StreamKKernelArgs MakeKernelArgs(const StreamKHostArgs& host_args, + int num_cu = NumCU(), + int occupancy = Occupancy()) { - uint32_t occupancy = static_cast(Occupancy()); - uint32_t num_cu = static_cast(NumCU()); - return StreamKKernelArgs{{host_args.as_ptr, host_args.bs_ptr, host_args.ds_ptr, @@ -166,14 +172,71 @@ struct StreamKKernel TilePartitioner{static_cast(host_args.M), static_cast(host_args.N), static_cast(host_args.K), - num_cu, - occupancy, + static_cast(num_cu), + static_cast(occupancy), host_args.num_sk_blocks}}; } - CK_TILE_HOST static bool - IsSupportedArgument(const typename UniversalGemmKernel::KernelArgs& kargs) + template + CK_TILE_DEVICE static void + RunGemm(const std::array& as_ptr, + const std::array& bs_ptr, + const std::array& ds_ptr, + CDataType* c_ptr, + void* smem_ptr_0, + const typename UniversalGemmKernel::KernelArgs& kargs, + const index_t num_loop, + const index_t block_idx_m, + const index_t block_idx_n, + const index_t k_size) { + // Create Gemm tensor views, pad views and tile windows + const auto& gemm_tensor_views_tuple = + UniversalGemmKernel::template MakeGemmTensorViews( + as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size); + + const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); + auto gemm_tile_windows = + UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + + // Run GEMM cooperatively by whole workgroup. + const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); + const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); + const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + + // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute + // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this + // case, we call the GemmPipeline's operator() function that takes both has_hot_loop and + // tail_num. 
+ const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); + const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], + bs_block_window[UniversalGemmKernel::I0], + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + + if(UseDefaultScheduler || (get_warp_id() == 0)) + { + // Run Epilogue Pipeline + auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + } + + CK_TILE_HOST static bool IsSupportedArgument(const StreamKKernelArgs& kargs) + { + if(kargs.reduction_strategy == StreamKReductionStrategy::Reduction) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("CK Tile Stream-K only supports the atomic reduction strategy."); + } + return false; + } return UniversalGemmKernel::IsSupportedArgument(kargs); } @@ -199,9 +262,81 @@ struct StreamKKernel kargs.workspace_ptr = workspace_ptr; } - // Temporary placeholder to support the Occupancy() static function. - // Since the Occupancy function uses kentry, this class must have an operator() function - CK_TILE_DEVICE void operator()(StreamKKernelArgs /*kargs*/) const {} + /// @brief Entry point for the Stream-K Kernel, performing the main Stream-K loop. 
+ CK_TILE_DEVICE void operator()(StreamKKernelArgs kargs) const + { + // Allocate LDS + __shared__ char smem_ptr_0[UniversalGemmKernel::GetSmemSize()]; + + uint32_t block_idx = ck_tile::get_block_1d_id(); + + bool is_padding_block = + __builtin_amdgcn_readfirstlane(block_idx >= kargs.tile_partitioner.sk_num_blocks && + block_idx < kargs.tile_partitioner.dp_start_block_idx); + + // Padding blocks make it such that the DP blocks are aligned with the number of CUs; they + // should not partake in the GEMM + if(is_padding_block) + return; + + // Determine the K offset of the first and final macro tile in the A and B tensors along the + // K dimension. + uint32_t iter_start, iter_end; + kargs.tile_partitioner.GetBlockItr(block_idx, iter_start, iter_end); + + // Main Stream-K loop + while(true) + { + // Determine the number of macro tiles in A and B this WG is responsible for in the + // current C macro tile. + uint32_t current_iter_length = __builtin_amdgcn_readfirstlane( + kargs.tile_partitioner.GetCurrentIterLength(iter_start, iter_end)); + + // Determine the 1D tile_idx and the iter_offset for this WG. + // The tile_idx is the 1D macro tile index in the C tensor. + // The iter_offset is the starting macro tile index in the K dimension for the WG in the + // current iteration of the while loop. + uint32_t tile_idx, iter_offset; + kargs.tile_partitioner.GetTileIdxWithOffset(iter_start, tile_idx, iter_offset); + + // Get the 2D tile index in the C tensor for this WG using the 1D index (i.e. tile_idx) + auto spatial_idx = kargs.tile_partitioner.GetOutputTileIndex(tile_idx); + + // Get the offsets in A, B, C tensors. 
+ index_t i_m = static_cast(spatial_idx[UniversalGemmKernel::I0] * + TilePartitioner::MPerBlock); + index_t i_n = static_cast(spatial_idx[UniversalGemmKernel::I1] * + TilePartitioner::NPerBlock); + index_t i_k = static_cast(iter_offset) * TilePartitioner::KPerBlock; + + // Determine the total size along the K dimension the WG is using in this iteration + // (used to construct tensor views). + index_t k_size = static_cast(current_iter_length * TilePartitioner::KPerBlock); + + // Update pointer offsets for A, B, and C. + const ADataType* a_ptr = static_cast(kargs.as_ptr[0]) + i_k; + const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + i_k; + CDataType* c_ptr = static_cast(kargs.e_ptr); + + // Run the GEMM pipeline and Epilogue. + RunGemm({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr_0, + kargs, + current_iter_length, + i_m, + i_n, + k_size); + + // Prepare for next Stream-K loop iteration. + iter_start += current_iter_length; + if(iter_end <= iter_start) + break; + block_sync_lds(); + } + } private: CK_TILE_HOST static int NumCU() diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 8117d65758..8f44108cc4 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -157,23 +157,23 @@ struct UniversalGemmKernel using EpiloguePipeline = remove_cvref_t; static constexpr bool ADataTypeIsTuple = - is_detected::value; + is_detected::value; static constexpr bool BDataTypeIsTuple = - is_detected::value; + is_detected::value; static constexpr bool DDataTypeIsTuple = is_detected::value; static constexpr bool ALayoutIsTuple = - is_detected::value; + is_detected::value; static constexpr bool BLayoutIsTuple = - is_detected::value; + is_detected::value; static constexpr bool DLayoutIsTuple = is_detected::value; using AsLayout = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using BsLayout = 
std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using DsLayout = std::conditional_t>>; using AsDataType = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using BsDataType = std::conditional_t, + remove_cvref_t, remove_cvref_t>>; using DsDataType = @@ -193,9 +193,12 @@ struct UniversalGemmKernel remove_cvref_t, remove_cvref_t>>; - using ELayout = remove_cvref_t; + using CLayout = remove_cvref_t; using EDataType = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + static constexpr index_t kBlockSize = GemmPipeline::BlockSize; // Get the persistent kernel if the pipeline has it available @@ -483,7 +486,7 @@ struct UniversalGemmKernel bool DTesnorIsValid = {true}; static_for<0, NumDTensor, 1>{}([&](auto index) { using DiLayout = remove_cvref_t>; - if(std::is_same_v == false) + if(std::is_same_v == false) { DTesnorIsValid = false; } @@ -529,7 +532,7 @@ struct UniversalGemmKernel } }); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false) { @@ -579,7 +582,7 @@ struct UniversalGemmKernel const std::array& ds_ptr, EDataType* e_ptr, const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + const index_t k_size) { static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!"); @@ -591,7 +594,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( static_cast(as_ptr[i]), - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_As[i], 1), number{}, number<1>{}); @@ -600,7 +603,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( static_cast(as_ptr[i]), - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_As[i], 1), number{}, number<1>{}); @@ -617,7 +620,7 @@ struct UniversalGemmKernel if constexpr(TilePartitioner::BlockGemmShape::PermuteB) { constexpr index_t K1 = 
GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; + const index_t K0 = k_size / K1; constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); const auto b_k0_n_k1_desc = @@ -638,7 +641,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( bs_ptr[i], - make_tuple(splitk_batch_offset.splitted_k, kargs.N), + make_tuple(k_size, kargs.N), make_tuple(kargs.stride_Bs[i], 1), number{}, number<1>{}); @@ -649,7 +652,7 @@ struct UniversalGemmKernel if constexpr(TilePartitioner::BlockGemmShape::PermuteB) { constexpr index_t K1 = GemmPipeline::GetSmemPackB(); - const index_t K0 = splitk_batch_offset.splitted_k / K1; + const index_t K0 = k_size / K1; constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); const auto b_k0_n_k1_desc = @@ -672,7 +675,7 @@ struct UniversalGemmKernel { index_t kFlatK = GemmPipeline::BlockGemmShape::flatKPerWarp * - (splitk_batch_offset.splitted_k / + (k_size / TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{})); index_t kFlatN = kargs.N * kargs.K / kFlatK; @@ -687,7 +690,7 @@ struct UniversalGemmKernel { return make_naive_tensor_view( bs_ptr[i], - make_tuple(kargs.N, splitk_batch_offset.splitted_k), + make_tuple(kargs.N, k_size), make_tuple(kargs.stride_Bs[i], 1), number{}, number<1>{}); @@ -724,7 +727,7 @@ struct UniversalGemmKernel // TODO: enable vector write for C in ColMajor const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return make_naive_tensor_view( e_ptr, @@ -818,7 +821,7 @@ struct UniversalGemmKernel // TODO vector write in for C in ColMajor const auto& e_pad_view = [&]() { const auto& e_tensor_view = views.at(I3); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, make_tuple(number{}, @@ -962,7 +965,7 @@ struct UniversalGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = 
MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -975,8 +978,8 @@ struct UniversalGemmKernel const auto& bs_block_window = gemm_tile_windows.at(I1); const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = - GemmPipeline{}(as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0); + const auto& c_block_tile = GemmPipeline{}.template operator()( + as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0); if(UseDefaultScheduler || (get_warp_id() == 0)) { @@ -1018,7 +1021,7 @@ struct UniversalGemmKernel // Create Gemm tensor views, pad views and tile windows const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); + as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k); const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); @@ -1031,8 +1034,13 @@ struct UniversalGemmKernel const auto& bs_block_window = gemm_tile_windows.at(I1); const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}( - as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1); + const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window, + AElementWise{}, + bs_block_window, + BElementWise{}, + num_loop, + smem_ptr_0, + smem_ptr_1); // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(I3); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 2bee550b3c..b5584f98df 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -11,12 +11,17 @@ namespace ck_tile { template struct GemmPipelineAgBgCrImplBase { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using ADataType = remove_cvref_t{}, AsDataType>>; + using ALayout = remove_cvref_t{}, AsLayout>>; + using BDataType = remove_cvref_t{}, BsDataType>>; + using BLayout = remove_cvref_t{}, BsLayout>>; + static constexpr index_t MPerBlock = BlockGemmShape::kM; static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t KPerBlock = BlockGemmShape::kK; @@ -57,6 +62,13 @@ struct GemmPipelineAgBgCrImplBase store_tile(lds_tile_window, block_tile_tmp); } + template + CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window, + const SrcBlockTile& src_block_tile) const + { + store_tile(lds_tile_window, src_block_tile); + } + template CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile, const SrcTileWindow& lds_tile_window, @@ -88,23 +100,100 @@ struct GemmPipelineAgBgCrImplBase return make_tuple(std::move(a_lds_block), std::move(b_lds_block)); } + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_col_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + 
dram_block_window_tmp[number{}].get_window_origin() + offset, + Policy::template MakeADramTileDistribution()); + }, + number{}); + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_col_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp.get_window_origin() + offset, + Policy::template MakeADramTileDistribution()); + + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_row_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // B DRAM tile window for load + auto a_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp[number{}].get_window_origin() + offset, + Policy::template MakeBDramTileDistribution()); + }, + number{}); + return std::move(a_copy_dram_window); + } + + template ::value, bool>* = + nullptr> + CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp, + const array& offset = {0, 0}) const + { + constexpr bool is_row_major = std::is_same_v; + + using YPerTile = std::conditional_t, number>; + using XPerTile = std::conditional_t, number>; + // B DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(), + 
make_tuple(YPerTile{}, XPerTile{}), + dram_block_window_tmp.get_window_origin() + offset, + Policy::template MakeBDramTileDistribution()); + + return std::move(a_copy_dram_window); + } + template CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp, const ALdsTensorView& a_lds_block_view, const ALdsLoadTileDistr&, const array& offset = {0, 0}) const { - constexpr bool is_col_major = std::is_same_v; - - using YPerTile = std::conditional_t, number>; - using XPerTile = std::conditional_t, number>; - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile{}, XPerTile{}), - a_dram_block_window_tmp.get_window_origin() + offset, - Policy::template MakeADramTileDistribution()); + auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset); // A LDS tile window for store auto a_lds_shape = []() { @@ -138,16 +227,8 @@ struct GemmPipelineAgBgCrImplBase const BLdsLoadTileDistr&, const array& offset = {0, 0}) const { - constexpr bool is_row_major = std::is_same_v; - - using YPerTile = std::conditional_t, number>; - using XPerTile = std::conditional_t, number>; - - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(YPerTile{}, XPerTile{}), - b_dram_block_window_tmp.get_window_origin() + offset, - Policy::template MakeBDramTileDistribution()); + // B DRAM tile window for load + auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset); // TODO: Do we really need those two tile windows??? // They're exactly same... 
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 5f4ee8987e..7159eda683 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -107,14 +107,23 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using Base = BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; using I0 = number<0>; @@ -386,17 +395,25 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + 
remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); - - using ABlockTile = - decltype(make_static_distributed_tensor(ABlockTileDistr{})); - using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); - - ABlockTile a_block_tile; - BBlockTile b_block_tile; - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -470,45 +476,61 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // ----------------------------------------------------------------------------------------- // Gemm pipeline start - - // prefetch - // global read 0 - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. 
+ move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + // LDS write 0 if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // global read 1 + + elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); __builtin_amdgcn_sched_barrier(0); @@ -520,38 +542,42 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if 
constexpr(is_a_col_major && !is_a_load_tr_v()) + if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } - if constexpr(is_b_row_major && !is_b_load_tr_v()) + if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } - Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); @@ -574,27 +600,26 @@ struct GemmPipelineAgBgCrCompV3 : public 
BaseGemmPipelineAgBgCrCompV3 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } block_sync_lds(); - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); } // __builtin_amdgcn_sched_barrier(0); @@ -602,13 +627,16 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const @@ -628,9 +656,13 @@ struct GemmPipelineAgBgCrCompV3 : public 
BaseGemmPipelineAgBgCrCompV3 * @note This is used by the persistent gemm kernel variants that don't determine * hot loop and tail number on the host side, e.g. grouped gemm kernel. */ - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, @@ -639,7 +671,7 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -658,20 +690,97 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 * @note This is used by the kernel variants that are able to determine * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. 
*/ - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } + + /** + * @brief Quant operator(), single input: This function runs the pipeline by wrapping it with + * the tail handler. + * + * @note This is used by the persistent gemm kernel variants that don't determine + * hot loop and tail number on the host side, e.g. grouped gemm kernel. 
+ */ + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem); + } + + /** + * @brief Quant operator(), single input: This function runs the pipeline using compile-time + * known hot loop and tail number. + * @param num_loop The number of loop iterations. This is determined at runtime due to e.g. + * SplitK. + * @note This is used by the kernel variants that are able to determine + * hot loop and tail number on the host side, e.g. non-persistent gemm kernel. + */ + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index c835809b5d..b362f751c6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -97,11 +97,24 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 using Base = BaseGemmPipelineAgBgCrCompV4; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; 
using BlockGemmShape = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + static_assert(!std::is_same_v, "Not implemented"); static constexpr index_t APackedSize = @@ -109,10 +122,6 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 static constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; using I0 = number<0>; using I1 = number<1>; @@ -244,18 +253,26 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* __restrict__ p_smem_0, void* __restrict__ p_smem_1) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]), "B block window has incorrect lengths for defined BLayout!"); - ////////////// global window & register ///////////////// - // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - 
make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); - - // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); - - // A register tile for global load - constexpr auto ABlockTileDistr = a_copy_dram_window.get_tile_distribution(); - constexpr auto BBlockTileDistr = b_copy_dram_window.get_tile_distribution(); - using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr)); - using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr)); - ABlockTile a_global_load_tile; - BBlockTile b_global_load_tile; - using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -312,8 +306,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // global prefetch 0 // global read 0 - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + ////////////// LDS desc, window & register ///////////////// auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); @@ -343,34 +336,75 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // Generating a tuple with tile_windows for values A0, A1, ... 
AN + auto a_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); + + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + // Generating a tuple with tile_windows for values B0, B1, ... BN + auto b_tile_windows = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. 
+ move_tile_window(b_tile_windows, b_dram_tile_window_step); + // LDS write 0 if constexpr(is_a_col_major && !is_a_load_tr_v()) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } // global read 1 - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); block_sync_lds(); constexpr auto ALdsTileDistr = @@ -423,27 +457,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - 
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res); } - Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. 
+ elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); if(HasHotLoop) { @@ -461,31 +500,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill( - a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill( - b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); // gemm block_gemm(c_block_tile, a_block_tile0, b_block_tile0); 
HotLoopScheduler(); @@ -501,32 +541,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp); } else { - Base::LocalPrefill( - a_copy_lds_window1, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp); } else { - Base::LocalPrefill( - b_copy_lds_window1, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res); } block_sync_lds(); - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + elementwise_As_res = + load_tile_with_elementwise(a_tile_windows, a_element_func); + move_tile_window(a_tile_windows, a_dram_tile_window_step); + + elementwise_Bs_res = + load_tile_with_elementwise(b_tile_windows, b_element_func); + move_tile_window(b_tile_windows, b_dram_tile_window_step); + // gemm block_gemm(c_block_tile, a_block_tile1, b_block_tile1); HotLoopScheduler(); @@ -548,23 +590,23 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, 
a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res); } block_gemm(c_block_tile, a_block_tile0, b_block_tile0); } @@ -606,13 +648,17 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem_0, @@ -628,27 +674,34 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 p_smem_1); } - public: - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const 
BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, void* __restrict__ p_smem_0, void* __restrict__ p_smem_1) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem_0, p_smem_1); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, @@ -658,7 +711,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -670,5 +723,69 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 }; return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0, + void* p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem_0, + 
p_smem_1); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem_0, + p_smem_1); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* __restrict__ p_smem_0, + void* __restrict__ p_smem_1) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem_0, + p_smem_1); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp index b83d37a790..474d1a5a21 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5.hpp @@ -41,15 +41,24 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 using Base = BaseGemmPipelineAgBgCrCompV5; using PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; using CDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + 
using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; @@ -121,17 +130,25 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BsDramBlockWindowTmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* __restrict__ p_smem_0) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v BGemmTile b_tile_0, b_tile_1; // Register tile for A and B. 
- using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); - ABlockTile a_global_load_tile; - BBlockTile b_global_load_tile; + ABlockTile elementwise_As_res; + BBlockTile elementwise_Bs_res; // Block GEMM auto block_gemm = BlockGemm(); @@ -248,33 +267,45 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 // define ping, pong steps here as lambda functions. auto MemoryOpsStep = [&](auto idx) { // Memory read half here. - Base::GlobalPrefetch( - a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each + // A0, A1, … AN. The values A0, A1, … AN are read by the same thread. + elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a + // tuple as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each + // B0, B1, … BN. The values B0, B1, … BN are read by the same thread. + elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a + // tuple as input. 
+ move_tile_window(b_copy_dram_window, b_dram_tile_window_step); if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_global_load_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_global_load_tile, a_element_func); + Base::LocalPrefill(a_copy_lds_window, elementwise_As_res); } if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_global_load_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_global_load_tile, b_element_func); + Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res); } if(idx == 0) @@ -351,13 +382,17 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem_0) const @@ -371,21 +406,62 @@ struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5 p_smem_0); } - public: - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& 
b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const index_t num_loop, void* __restrict__ p_smem_0) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem_0); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_0) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem_0); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const index_t num_loop, + void* __restrict__ p_smem_0) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem_0); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index e1acfebc47..9e522d4364 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -157,14 +157,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem using Base = BaseGemmPipelineAgBgCrMem; using 
PipelineImplBase = GemmPipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; @@ -236,17 +245,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + 
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); @@ -334,10 +353,21 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // prefetch // global read 0 - Base::GlobalPrefetch( - a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. 
+ move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -348,32 +378,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{}), b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{})); } // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); // main body @@ -397,14 +430,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % 
PrefetchStages>{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill( a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - a_element_func); + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { @@ -413,22 +445,23 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill( b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - b_element_func); + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); i += PrefetchStages; @@ -450,26 +483,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); + a_block_tiles.get(number{})); } 
if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + b_block_tiles.get(number{})); } }); @@ -526,17 +557,25 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); - using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + using ABlockTileDistr = + decltype(a_copy_dram_window[number<0>{}].get_tile_distribution()); + using BBlockTileDistr = + decltype(b_copy_dram_window[number<0>{}].get_tile_distribution()); using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); @@ -623,10 +664,22 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // prefetch // global read 0 - Base::GlobalPrefetch( - a_block_tiles.get(I0{}), a_copy_dram_window, 
a_dram_tile_window_step); - Base::GlobalPrefetch( - b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + a_block_tiles.at(I0{}) = load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + b_block_tiles.at(I0{}) = load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -637,32 +690,35 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(I0{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { - Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{}), a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_block_tiles.get(I0{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(I0{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { - Base::LocalPrefill(b_copy_lds_window, 
b_block_tiles.get(I0{}), b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_block_tiles.get(I0{})); } // Global prefetch [1, PrefetchStages] static_for<1, PrefetchStages, 1>{}([&](auto prefetch_idx) { - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); // main body @@ -687,14 +743,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( a_shuffle_tmp, a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill( a_copy_lds_window, - a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - a_element_func); + a_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { @@ -703,22 +758,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem transpose_tile2d( b_shuffle_tmp, b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill( b_copy_lds_window, - b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{}), - b_element_func); + b_block_tiles.get(number<(prefetch_idx + 1) % PrefetchStages>{})); } - Base::GlobalPrefetch(a_block_tiles.get(number{}), - a_copy_dram_window, - a_dram_tile_window_step); - 
Base::GlobalPrefetch(b_block_tiles.get(number{}), - b_copy_dram_window, - b_dram_tile_window_step); + a_block_tiles.at(number{}) = + load_tile_with_elementwise(a_copy_dram_window, a_element_func); + + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + + b_block_tiles.at(number{}) = + load_tile_with_elementwise(b_copy_dram_window, b_element_func); + + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); }); i += PrefetchStages; @@ -740,26 +797,24 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); transpose_tile2d(a_shuffle_tmp, a_block_tiles.get(number{})); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); + Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp); } else { Base::LocalPrefill(a_copy_lds_window, - a_block_tiles.get(number{}), - a_element_func); + a_block_tiles.get(number{})); } if constexpr(is_b_row_major && !is_b_load_tr_v()) { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tiles.get(number{})); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); + Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp); } else { Base::LocalPrefill(b_copy_lds_window, - b_block_tiles.get(number{}), - b_element_func); + b_block_tiles.get(number{})); } }); @@ -813,13 +868,16 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem } }; - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const 
BElementFunction& b_element_func, index_t num_loop, void* p_smem) const @@ -833,9 +891,13 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem p_smem); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, bool has_hot_loop, TailNumber tail_number, @@ -844,7 +906,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) { constexpr bool hot_loop = hot_loop_.value; constexpr auto tail_num = tail_num_.value; - constexpr auto PassThrough = [](const auto& x) { return x; }; + constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; }; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, PassThrough, @@ -856,20 +918,82 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem return Base::TailHandler(RunPipeline, has_hot_loop, tail_number); } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return PipelineImpl{}.template operator()( a_dram_block_window_tmp, - [](const ADataType& a) { return a; }, + [](auto& e, const ADataType& a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType& b) { return b; }, + [](auto& e, const BDataType& b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& 
a_dram_block_window_tmp, + const AElementFunction& a_element_func, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + a_element_func, + ck_tile::make_tuple(b_dram_block_window_tmp), + b_element_func, + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + bool has_hot_loop, + TailNumber tail_number, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + has_hot_loop, + tail_number, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index e3b4863392..eb363d59b8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -15,14 +15,23 @@ namespace ck_tile { template struct GemmPipelineAGmemBGmemCRegV1 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = 
remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockGemm = remove_cvref_t())>; @@ -81,17 +90,25 @@ struct GemmPipelineAGmemBGmemCRegV1 return Policy::template GetSmemSize(); } - template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert( std::is_same_v> && std::is_same_v>, @@ -133,22 +150,30 @@ struct GemmPipelineAGmemBGmemCRegV1 auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + auto as_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); // A LDS tile window for store 
auto a_copy_lds_window = make_tile_window( a_lds_block, make_tuple(number{}, number{}), {0, 0}); // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + auto bs_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); // B LDS tile window for store auto b_copy_lds_window = make_tile_window( @@ -182,13 +207,22 @@ struct GemmPipelineAGmemBGmemCRegV1 // prefetch // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - auto b_block_tile = load_tile(b_copy_dram_window); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); { // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + // Move each A — the enhanced function move_tile_window is executed, which takes a tuple + // as input. + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + // Move each B — the enhanced function move_tile_window is executed, which takes a tuple + // as input. 
+ move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -198,13 +232,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto a_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); - store_tile(a_copy_lds_window, a_block_tile_tmp); + transpose_tile2d(a_shuffle_tmp, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp); } else { - store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + store_tile(a_copy_lds_window, elementwise_As_res); } // LDS write 0 @@ -212,13 +245,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); - store_tile(b_copy_lds_window, b_block_tile_tmp); + transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp); } else { - store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile)); + store_tile(b_copy_lds_window, elementwise_Bs_res); } } @@ -226,8 +258,8 @@ struct GemmPipelineAGmemBGmemCRegV1 while(iCounter > 0) { // global read i + 1 - a_block_tile = load_tile(a_copy_dram_window); - b_block_tile = load_tile(b_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); block_sync_lds(); @@ -237,22 +269,20 @@ struct GemmPipelineAGmemBGmemCRegV1 block_sync_lds(); // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, 
kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // LDS write i + 1 if constexpr(is_a_col_major) { auto a_shuffle_tmp_loop = make_static_distributed_tensor( Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp_loop, a_block_tile); - store_tile(a_copy_lds_window, - tile_elementwise_in(a_element_func, a_shuffle_tmp_loop)); + transpose_tile2d(a_shuffle_tmp_loop, elementwise_As_res); + store_tile(a_copy_lds_window, a_shuffle_tmp_loop); } else { - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); } // LDS write i + 1 @@ -260,14 +290,12 @@ struct GemmPipelineAGmemBGmemCRegV1 { auto b_shuffle_tmp_loop = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp_loop, b_block_tile); - store_tile(b_copy_lds_window, - tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); + transpose_tile2d(b_shuffle_tmp_loop, elementwise_Bs_res); + store_tile(b_copy_lds_window, b_shuffle_tmp_loop); } else { - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); } iCounter--; @@ -284,20 +312,40 @@ struct GemmPipelineAGmemBGmemCRegV1 return c_block_tile; } - template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, index_t num_loop, void* p_smem) const { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, - [](const 
BDataType & b) { return b; }, + [](auto& e, const BDataType & b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp index b151cd6782..c309f8908a 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp @@ -15,30 +15,66 @@ namespace ck_tile { template struct GemmPipelineAGmemBGmemCRegV2 { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; + static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; static constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; - static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; static constexpr index_t kNPerBlock = BlockGemmShape::kN; static constexpr index_t kKPerBlock = BlockGemmShape::kK; + 
template + static constexpr index_t GetVectorSizeA() + { + return Problem::VectorSizeA; + } + template + static constexpr index_t GetVectorSizeB() + { + return Problem::VectorSizeB; + } + static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; } + static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA(); } static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB(); } + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; + + static constexpr bool Preshuffle = Problem::Preshuffle; + + static constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + // For the basic gemm pipeline, DoubleSmemBuffer is naturally set to false. + static constexpr bool DoubleSmemBuffer = false; + [[nodiscard]] CK_TILE_HOST static const std::string GetName() { // clang-format off return concat('_', "pipeline_AGmemBGmemCRegV2", - concat('x', kMPerBlock, kNPerBlock, kKPerBlock, kBlockSize)); + concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize)); // clang-format on } CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } @@ -56,17 +92,31 @@ struct GemmPipelineAGmemBGmemCRegV2 BPackedSize; } - template (); + } + + template - CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + typename BElementFunction, + typename std::enable_if_t::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, - const BDramBlockWindowTmp& b_dram_block_window_tmp, + const BsDramBlockWindowTmp& b_dram_block_window_tmp, const BElementFunction& b_element_func, index_t num_loop, void* p_smem) const { + + using ADramBlockWindowTmp = + remove_cvref_t{}, AsDramBlockWindowTmp>>; + using BDramBlockWindowTmp = + remove_cvref_t{}, BsDramBlockWindowTmp>>; + static_assert(
std::is_same_v> && std::is_same_v>, @@ -98,32 +148,40 @@ struct GemmPipelineAGmemBGmemCRegV2 auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); // A DRAM tile window for load - auto a_copy_dram_window = - make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - a_dram_block_window_tmp.get_window_origin(), - Policy::template MakeADramTileDistribution()); + auto as_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + a_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeADramTileDistribution()); + }, + number{}); // A LDS tile window for store auto a_copy_lds_window = make_tile_window(a_lds_block, make_tuple(number{}, number{}), {0, 0}, - a_copy_dram_window.get_tile_distribution()); + as_copy_dram_window[number<0>{}].get_tile_distribution()); // B DRAM tile window for load - auto b_copy_dram_window = - make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_dram_block_window_tmp.get_window_origin(), - Policy::template MakeBDramTileDistribution()); + auto bs_copy_dram_window = generate_tuple( + [&](auto idx) { + return make_tile_window( + b_dram_block_window_tmp[number{}].get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp[number{}].get_window_origin(), + Policy::template MakeBDramTileDistribution()); + }, + number{}); // B LDS tile window for store auto b_copy_lds_window = make_tile_window(b_lds_block, make_tuple(number{}, number{}), {0, 0}, - b_copy_dram_window.get_tile_distribution()); + bs_copy_dram_window[number<0>{}].get_tile_distribution()); // Block GEMM constexpr auto block_gemm = Policy::template GetBlockGemm(); @@ -153,28 +211,30 @@ struct GemmPipelineAGmemBGmemCRegV2 // prefetch // global read 0 - auto a_block_tile = load_tile(a_copy_dram_window); - auto b_block_tile = 
load_tile(b_copy_dram_window); + // Load tile — during value loading, an elementwise function is executed for each A0, + // A1, … AN. The values A0, A1, … AN are read by the same thread. + auto elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); + // Load tile — during value loading, an elementwise function is executed for each B0, + // B1, … BN. The values B0, B1, … BN are read by the same thread. + auto elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); { // move to 1 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); // global read 1 - a_block_tile = load_tile(a_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); // LDS write 0 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); // global read 1 - b_block_tile = load_tile(b_copy_dram_window); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); } index_t iCounter = num_loop - 2; @@ -189,20 +249,18 @@ struct GemmPipelineAGmemBGmemCRegV2 block_sync_lds(); // move to i + 2 - move_tile_window(a_copy_dram_window, {0, kKPerBlock}); - move_tile_window(b_copy_dram_window, {0, kKPerBlock}); + move_tile_window(as_copy_dram_window, {0, kKPerBlock}); + move_tile_window(bs_copy_dram_window, {0, kKPerBlock}); // LDS write i + 1 - const auto a_block_tile_tmp = 
tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); // global read i + 2 - a_block_tile = load_tile(a_copy_dram_window); + elementwise_As_res = load_tile_with_elementwise(as_copy_dram_window, a_element_func); // LDS write i + 1 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); // global read i + 2 - b_block_tile = load_tile(b_copy_dram_window); + elementwise_Bs_res = load_tile_with_elementwise(bs_copy_dram_window, b_element_func); iCounter--; @@ -218,11 +276,9 @@ struct GemmPipelineAGmemBGmemCRegV2 block_sync_lds(); // LDS write num_loop - 1 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + store_tile(a_copy_lds_window, elementwise_As_res); - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + store_tile(b_copy_lds_window, elementwise_Bs_res); block_sync_lds(); @@ -241,12 +297,28 @@ struct GemmPipelineAGmemBGmemCRegV2 { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_dram_block_window_tmp, - [](const BDataType & b) { return b; }, + [](auto& e, const BDataType & b) { e = b; }, num_loop, p_smem); } + + template ::value && + !is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + return operator()(ck_tile::make_tuple(a_dram_block_window_tmp), + ck_tile::make_tuple(b_dram_block_window_tmp), + num_loop, + p_smem); + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp 
b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 52bd07c9e2..c73fa29245 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -5,16 +5,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { -template @@ -22,18 +25,49 @@ struct GemmPipelineProblemBase { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; // actually AccDataType - using ComputeDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; // actually AccDataType static constexpr bool FixedVectorSize = FixedVectorSize_; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; + + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr bool ComputeDataTypeIsTuple = is_detected::value; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + static constexpr bool ALayoutIsTuple = is_detected::value; + static constexpr bool BLayoutIsTuple = is_detected::value; + + using ComputeDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + using AsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + using BsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ComputeDataType = remove_cvref_t{}, ComputeDataTypeTuple>>; + using 
ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using ALayout = remove_cvref_t{}, AsLayoutTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using BLayout = remove_cvref_t{}, BsLayoutTuple>>; static constexpr bool TransposeC = Traits::TransposeC; static constexpr index_t NumWaveGroups = Traits::NumWaveGroups; @@ -66,7 +100,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; @@ -84,7 +118,7 @@ struct GemmPipelineProblemBase { constexpr index_t PackedSize = ck_tile::numeric_traits>::PackedSize; - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { constexpr index_t pixels_per_thread = BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; @@ -125,7 +159,7 @@ struct GemmPipelineProblemBase { return VectorSizeA_; } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return kPadK ? 1 : GetAlignmentA(); } @@ -140,7 +174,7 @@ struct GemmPipelineProblemBase { return VectorSizeB_; } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { return kPadN ? 
1 : GetAlignmentB(); } @@ -161,35 +195,40 @@ struct GemmPipelineProblemBase }(); }; -// Alias for GemmPipelineProblem -template -using GemmPipelineProblem = GemmPipelineProblemBase; -template @@ -197,18 +236,48 @@ struct UniversalGemmPipelineProblem { using Traits = remove_cvref_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; // actually AccDataType - using ComputeDataType = remove_cvref_t; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; // actually AccDataType + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; static constexpr bool FixedVectorSize = FixedVectorSize_; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + static constexpr bool ComputeDataTypeIsTuple = is_detected::value; + static constexpr bool ADataTypeIsTuple = is_detected::value; + static constexpr bool BDataTypeIsTuple = is_detected::value; + + static constexpr bool ALayoutIsTuple = is_detected::value; + static constexpr bool BLayoutIsTuple = is_detected::value; + + using ComputeDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + using AsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + using BsLayoutTuple = std:: + conditional_t, remove_cvref_t>>; + + using AsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using BsDataTypeTuple = std::conditional_t, + remove_cvref_t>>; + + using ComputeDataType = remove_cvref_t{}, ComputeDataTypeTuple>>; + using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; + using ALayout = remove_cvref_t{}, AsLayoutTuple>>; + using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; + using BLayout = remove_cvref_t{}, BsLayoutTuple>>; static constexpr bool TransposeC = Traits::TransposeC; static constexpr 
index_t NumWaveGroups = Traits::NumWaveGroups; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8d47ab878e..c8f874acd6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -356,11 +356,14 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA() { - using ALayout = remove_cvref_t; - using ADataType = remove_cvref_t; + using AsLayout = remove_cvref_t; + using AsDataType = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using ALayout = remove_cvref_t{}, AsLayout>>; + using ADataType = remove_cvref_t{}, AsDataType>>; + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB() { - using BLayout = remove_cvref_t; - using BDataType = remove_cvref_t; + using BsLayout = remove_cvref_t; + using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + using BLayout = remove_cvref_t{}, BsLayout>>; + using BDataType = remove_cvref_t{}, BsDataType>>; + if constexpr(std::is_same_v) { return GetGlobalVectorLoadSize CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - using ALayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -491,6 +495,8 @@ struct UniversalGemmBasePolicy Problem::FixedVectorSize ? 
Problem::VectorSizeA : GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + using ALayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) { @@ -518,8 +524,6 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { - using BLayout = remove_cvref_t; - constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -527,6 +531,8 @@ struct UniversalGemmBasePolicy Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { @@ -554,7 +560,8 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution() { - using ALayout = remove_cvref_t; + using ALayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; @@ -574,7 +581,8 @@ struct UniversalGemmBasePolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution() { - using BLayout = remove_cvref_t; + using BLayout = remove_cvref_t< + std::tuple_element_t{}, remove_cvref_t>>; static_assert(std::is_same_v); constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 64900c9a97..96203b2cd2 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -10,8 +10,8 @@ namespace ck_tile { 
template struct TileGemmTraits @@ -23,9 +23,9 @@ struct TileGemmTraits // TODO this can't be hardcoded here! Should be in policy! static constexpr int _VectorSize = 16; - using ALayout = ALayout_; - using BLayout = BLayout_; - using CLayout = CLayout_; + using AsLayout = AsLayout_; + using BsLayout = BsLayout_; + using CLayout = CLayout_; static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; @@ -36,8 +36,8 @@ template @@ -76,8 +76,8 @@ using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits; + using BTypeToUse = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + using WarpGemm = WarpGemmDispatcher; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy { - using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV1; + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockWeightPreshuffle = remove_cvref_t())>; @@ -188,7 +198,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 } } - template + template ::value && + !is_detected::value, + bool>* = nullptr, + index_t UnaryOpSize_ = 8> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& 
b_flat_dram_block_window_tmp, @@ -296,14 +312,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 NIterPerWarp> b_flat_dram_windows; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + + statically_indexed_array, NIterPerWarp> b_warp_tensor; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_2; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -313,7 +329,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -361,7 +378,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -394,7 +412,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -431,7 +450,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_2(nIter)(kIter), 
b_flat_dram_windows(nIter)(kIter)); }); }); @@ -455,7 +475,33 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 return c_block_tile; } - template + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + [[maybe_unused]] const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp[number<0>{}], + [](const ADataType & a) { return a; }, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem); + } + + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, @@ -463,7 +509,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 { return operator()( a_dram_block_window_tmp, - [](const ADataType & a) { return a; }, + [](auto& e, const ADataType & a) { e = a; }, b_flat_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 129eac6557..670f4b0575 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" @@ -53,14 +54,23 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 { using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using 
AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + + using AElementWise = remove_cvref_t; + using BElementWise = remove_cvref_t; using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using AsLayout = remove_cvref_t; + using BsLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + + using ALayout = remove_cvref_t>; + using BLayout = remove_cvref_t>; + + using ADataType = remove_cvref_t>; + using BDataType = remove_cvref_t>; using BlockWeightPreshuffle = remove_cvref_t())>; @@ -502,7 +512,11 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 template + typename AElementFunction, + typename std::enable_if_t::value && + !is_detected::value, + bool>* = nullptr, + index_t UnaryOpSize_ = 8> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -619,19 +633,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 b_flat_distribution); // pingpong buffer for B + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + statically_indexed_array< statically_indexed_array, NIterPerWarp> b_flat_dram_windows; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_ping; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_pong; // Prefetch A0 @@ -647,7 +661,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), 
b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -694,7 +709,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -770,7 +786,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -850,7 +867,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -1001,8 +1019,37 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 return c_block_tile; } + // called from universal gemm kernel + template ::value && + is_detected::value, + bool>* = nullptr> + CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + [[maybe_unused]] const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + [[maybe_unused]] const BElementFunction& b_element_func, + index_t num_loop, + void* p_smem_ping, + void* p_smem_pong) const + { + return operator()( + a_dram_block_window_tmp[number<0>{}], + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp[number<0>{}], + num_loop, + p_smem_ping, + p_smem_pong); + } + // called from general gemm kernel - template + template ::value && + !is_detected::value, + bool>* = 
nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, @@ -1019,9 +1066,13 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 } // called from grouped gemm kernel - template + template ::value && + !is_detected::value, + bool>* = nullptr> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, - const BDramBlockWindowTmp& b_flat_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, index_t num_loop, TailNumber tail_number, void* __restrict__ p_smem_0, diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 182d9251b1..f75d02f1a6 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -5,19 +5,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" namespace ck_tile { -template +template struct BlockGemmAQuantBase { using AQDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; - static constexpr index_t UnaryOpSize = UnaryOpSize_; template CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) { @@ -42,23 +42,6 @@ struct BlockGemmAQuantBase } return scale_reg_f; } - - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - const element_wise::PassThroughPack8 elementwise_op{}; - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - constexpr index_t thread_buffer_size = 
WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } }; // A is block window on shared memory @@ -66,7 +49,9 @@ struct BlockGemmAQuantBase // Consecutive kQuantGroupSize elements of A are quantized with a separate scale. // B is block window on shared memory // C is block distributed tensor -template +template struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { private: @@ -172,6 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase using Base = BlockGemmAQuantBase; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -292,7 +278,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -302,7 +288,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 7e28ea8fa9..077d0d8fe2 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -5,19 +5,19 @@ #include "ck_tile/core.hpp" #include 
"ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" namespace ck_tile { -template +template struct BlockGemmBQuantBase { using BQDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; - static constexpr index_t UnaryOpSize = UnaryOpSize_; template CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) { @@ -42,24 +42,6 @@ struct BlockGemmBQuantBase } return scale_reg_f; } - - // can be inherited from A - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - const element_wise::PassThroughPack8 elementwise_op{}; - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } }; // A is block window on shared memory @@ -67,7 +49,9 @@ struct BlockGemmBQuantBase // Consecutive kQuantGroupSize elements of B are quantized with a separate scale. 
// B is block window on shared memory // C is block distributed tensor -template +template struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { private: @@ -170,6 +154,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using Base = BlockGemmBQuantBase; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -291,7 +276,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -301,7 +286,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp index 44c6cd66c6..f505efe4e0 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp @@ -44,6 +44,10 @@ struct TileGemmQuantTraits using AQLayout = AQLayout_; using BQLayout = BQLayout_; + // TODO: It should be replaced to single value + using AsLayout = ALayout_; + using BsLayout = BLayout_; + static constexpr bool TransposeC = false; static constexpr bool UseStructuredSparsity = false; static constexpr index_t NumWaveGroups = 1; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index cf4eca7a2d..6fcef5502e 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ 
b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -23,7 +23,8 @@ struct GroupedConvFwdKernelArgs using ConvToGemmFwdTransformer = TransformConvFwdToGemm; + GroupedConvTraitsType_::ConvSpecialization, + true>; // Split N enabled static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; template < @@ -56,7 +57,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0]; + // GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0]; GemmBatch = args.G_; @@ -94,6 +95,19 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 1D convolution (NWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NWGC layout + input_batch_stride = args.C_ * args.input_spatial_lengths_[0]; + output_batch_stride = args.K_ * args.output_spatial_lengths_[0]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0]; } template < @@ -133,7 +147,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + // Note: GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1]; GemmBatch = args.G_; @@ -171,6 +185,21 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 2D convolution (NHWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = 
ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NHWGC layout + input_batch_stride = + args.C_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1]; + output_batch_stride = + args.K_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1]; } template < @@ -217,8 +246,7 @@ struct GroupedConvFwdKernelArgs k_batch = args.k_batch; - GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * - args.output_spatial_lengths_[2]; + // Note: GemmM will be set after Split-N calculation GemmN = args.K_; GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] * args.filter_spatial_lengths_[2]; @@ -257,6 +285,22 @@ struct GroupedConvFwdKernelArgs 1, std::multiplies()); group_stride_c = args.K_; + + // Initialize Split-N support fields for 3D convolution (NDHWGC layout) + // Get the actual split N from transformer + n_per_split = conv_to_gemm_transformer.GetN(); + original_n = conv_to_gemm_transformer.GetOriginalN(); + n_splits = ck_tile::integer_divide_ceil(original_n, n_per_split); + + // Calculate batch strides for NDHWGC layout + input_batch_stride = args.C_ * args.input_spatial_lengths_[0] * + args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2]; + output_batch_stride = args.K_ * args.output_spatial_lengths_[0] * + args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2]; + + // Update GemmM to use split N (not original N) + GemmM = n_per_split * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] * + args.output_spatial_lengths_[2]; } using AGridDescMK = remove_cvref_t< @@ -297,6 +341,13 @@ struct GroupedConvFwdKernelArgs long_index_t group_stride_a; long_index_t group_stride_b; long_index_t group_stride_c; + + // Split-N support fields - initialize to safe defaults + index_t 
n_splits = 1; // Number of batch splits (e.g., 2 for 128→64×2) + index_t n_per_split = 1; // Batches per split (N_ from transformer) + index_t original_n = 1; // Original batch size before splitting + index_t input_batch_stride = 0; // Stride to next batch in input tensor + index_t output_batch_stride = 0; // Stride to next batch in output tensor }; /// @brief The Grouped Convolution Forward kernel template. @@ -392,10 +443,10 @@ struct GroupedConvolutionForwardKernel // clang-format on } - CK_TILE_HOST static constexpr auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) + CK_TILE_HOST static auto GridSize(const GroupedConvFwdKernelArgsSpecialized& kargs) { return dim3( - TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch); + TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.n_splits); } CK_TILE_HOST static auto BlockSize() @@ -430,6 +481,17 @@ struct GroupedConvolutionForwardKernel } } + // Check Split-K and Split-N conflict (both use blockIdx.z) + if(kargs.k_batch > 1 && kargs.n_splits > 1) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR( + "Cannot use both Split-K and Split-N simultaneously (both use blockIdx.z)!"); + } + return false; + } + const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}]; const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}]; @@ -768,10 +830,26 @@ struct GroupedConvolutionForwardKernel const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY); const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY); - // options - const InDataType* a_ptr = static_cast(kargs.in_ptr) + group_offset_a; - const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + group_offset_b; - OutDataType* c_ptr = static_cast(kargs.out_ptr) + group_offset_c; + // Split-N handling: Get which split this workgroup handles + const auto blockIdZ = __builtin_amdgcn_readfirstlane(blockIdx.z); + 
+ // Calculate batch offset for this split + const index_t batch_offset = __builtin_amdgcn_readfirstlane(blockIdZ * kargs.n_per_split); + + // Calculate memory offsets for this split + const long_index_t input_batch_offset = static_cast(batch_offset) * + static_cast(kargs.input_batch_stride); + const long_index_t output_batch_offset = + static_cast(batch_offset) * + static_cast(kargs.output_batch_stride); + + // Adjust pointers: combine group offset and batch offset + const InDataType* a_ptr = + static_cast(kargs.in_ptr) + group_offset_a + input_batch_offset; + const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + + group_offset_b; // No batch offset for weights! + OutDataType* c_ptr = + static_cast(kargs.out_ptr) + group_offset_c + output_batch_offset; // allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp index c468ae4398..2663d8a494 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp @@ -24,7 +24,7 @@ struct TransformConvFwdToGemm static constexpr auto I3 = number<3>{}; static constexpr auto I4 = number<4>{}; static constexpr auto I5 = number<5>{}; -#if 0 // TODO: Enable these functionalities + template static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths, const ConvDimsType& strides, @@ -42,24 +42,40 @@ struct TransformConvFwdToGemm template static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths, - const ConvDimsType& a_g_n_c_wis_strides, - const ConvDimsType& c_g_n_k_wos_lengths, - const ConvDimsType& c_g_n_k_wos_strides) + const ConvDimsType& c_g_n_k_wos_lengths) { + // Calculate strides internally assuming contiguous memory layout + ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides; + const index_t num_dims = 
a_g_n_c_wis_lengths.size(); + + // Calculate strides for input tensor (innermost to outermost) + a_g_n_c_wis_strides[num_dims - 1] = 1; + for(index_t i = num_dims - 2; i >= 0; i--) + { + a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1]; + } + + // Calculate strides for output tensor + c_g_n_k_wos_strides[num_dims - 1] = 1; + for(index_t i = num_dims - 2; i >= 0; i--) + { + c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1]; + } + const long_index_t a_element_space_size = calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1); const long_index_t c_element_space_size = calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1); - const long_index_t element_space_size = math::max(a_element_space_size * sizeof(ADataType), - c_element_space_size * sizeof(CDataType)); - constexpr long_index_t TwoGB = (long_index_t{1} << 31); + const long_index_t element_space_size = ck_tile::max( + a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType)); + constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB const IndexType N = a_g_n_c_wis_lengths[I1]; if(element_space_size > TwoGB) { // Minimum divisor of N to not exceed 2GB - const auto divisor = math::integer_divide_ceil(element_space_size, TwoGB); + const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB); if(divisor <= static_cast(N)) { @@ -70,7 +86,8 @@ struct TransformConvFwdToGemm { if(N % least_divisor == 0) { - return N / least_divisor; + IndexType result = N / least_divisor; + return result; } } // Not found, process one Convolution N per block @@ -90,9 +107,12 @@ struct TransformConvFwdToGemm return N; } } -#endif public: + // Public getter methods for Split-N support + CK_TILE_HOST constexpr IndexType GetN() const { return N_; } + CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; } + CK_TILE_HOST constexpr TransformConvFwdToGemm() {} 
template @@ -100,6 +120,7 @@ struct TransformConvFwdToGemm TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base) : G_{static_cast(transform_conv_fwd_to_gemm_base.G_)}, N_{static_cast(transform_conv_fwd_to_gemm_base.N_)}, + original_N_{static_cast(transform_conv_fwd_to_gemm_base.original_N_)}, Di_{static_cast(transform_conv_fwd_to_gemm_base.Di_)}, Hi_{static_cast(transform_conv_fwd_to_gemm_base.Hi_)}, Wi_{static_cast(transform_conv_fwd_to_gemm_base.Wi_)}, @@ -168,18 +189,14 @@ struct TransformConvFwdToGemm std::is_same_v>); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { N_ = c_g_n_k_wos_lengths[I1]; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } template >); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities + + // Store original N + original_N_ = c_g_n_k_wos_lengths[I1]; + if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { - N_ = c_g_n_k_wos_lengths[I1]; + N_ = c_g_n_k_wos_lengths[I1]; + original_N_ = N_; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } template >); static_assert(std::is_same_v> || std::is_same_v>); -#if 0 // TODO: Enable these functionalities + + // Store original N before potential splitting + original_N_ = c_g_n_k_wos_lengths[I1]; + if constexpr(SplitN) { - N_ = GetSplitedNSize( - a_g_n_c_wis_lengths, a_g_n_c_wis_strides, c_g_n_k_wos_lengths, c_g_n_k_wos_strides); + N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths); } else { - N_ = c_g_n_k_wos_lengths[I1]; + N_ = original_N_; } -#endif - N_ = c_g_n_k_wos_lengths[I1]; } #if 0 // TODO: Enable these 
functionalities @@ -1417,7 +1435,7 @@ struct TransformConvFwdToGemm } } - IndexType G_, N_; + IndexType G_, N_, original_N_; IndexType Di_, Hi_, Wi_; IndexType Do_, Ho_, Wo_; IndexType Z_, Y_, X_; diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 6998b358d8..0181a3291f 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -134,7 +134,11 @@ struct Layernorm2dFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index e7f4ce0ba8..32586a6343 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -124,7 +124,11 @@ struct Rmsnorm2dFwd return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? 
Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp index b70e996617..2553b19fd8 100644 --- a/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp @@ -93,7 +93,11 @@ struct MoeSmoothquant return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp index 7dc913901e..e0ea9692c5 100644 --- a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp +++ b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp @@ -82,7 +82,11 @@ struct Smoothquant return dim3(integer_divide_ceil(hargs.m, Block_M)); } - CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + CK_TILE_HOST static constexpr auto BlockSize() + { + return is_wave32() ? 
Problem::BlockShape::template GetBlockSize() + : Problem::BlockShape::template GetBlockSize(); + } // clang-format off template struct t2s; diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp index 3884902bbf..573571bc07 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp @@ -59,6 +59,7 @@ template = 1 && NDimSpatial <= 3, bool>::type = false> struct ReferenceConvFwd : public device::BaseOperator { @@ -163,8 +164,18 @@ struct ReferenceConvFwd : public device::BaseOperator k, c, x); - v_acc += - ck::type_convert(v_in) * ck::type_convert(v_wei); + if constexpr(is_same_v) + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else + { + v_acc += ck::type_convert(v_in) * + ck::type_convert(v_wei); + } } } } @@ -238,8 +249,18 @@ struct ReferenceConvFwd : public device::BaseOperator c, y, x); - v_acc += ck::type_convert(v_in) * - ck::type_convert(v_wei); + if constexpr(is_same_v) + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else + { + v_acc += ck::type_convert(v_in) * + ck::type_convert(v_wei); + } } } } @@ -327,8 +348,18 @@ struct ReferenceConvFwd : public device::BaseOperator z, y, x); - v_acc += ck::type_convert(v_in) * - ck::type_convert(v_wei); + if constexpr(is_same_v) + { + v_acc += ck::type_convert( + ck::type_convert(v_in)) * + ck::type_convert( + ck::type_convert(v_wei)); + } + else + { + v_acc += ck::type_convert(v_in) * + ck::type_convert(v_wei); + } } } } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp index ed07e53e6d..8b9b973b2d 100644 --- 
a/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp @@ -25,6 +25,12 @@ template struct ReferenceGemm : public device::BaseOperator { + + using ElementDataTypeA = + ck::conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + ck::conditional_t, float, ComputeTypeB>; + // Argument struct Argument : public device::BaseArgument { @@ -63,8 +69,8 @@ struct ReferenceGemm : public device::BaseOperator const int K = arg.a_m_k_.mDesc.GetLengths()[1]; AccDataType v_acc{0}; - ComputeTypeA v_a{0}; - ComputeTypeB v_b{0}; + ElementDataTypeA v_a{0}; + ElementDataTypeB v_b{0}; for(int k = 0; k < K; ++k) { @@ -77,16 +83,16 @@ struct ReferenceGemm : public device::BaseOperator else i4 = (i4x2 >> 4) & 0xf; i4 = i4 - 8; - v_a = type_convert(i4); + v_a = type_convert(i4); } else if constexpr(is_same_v) { // TODO: add support for ColMajor layout as well if(k % 2 == 1) - v_a = type_convert( + v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<1>{}))); else - v_a = type_convert( + v_a = type_convert( f4_t(arg.a_m_k_(m, k).template unpack<>(Number<0>{}))); } else if constexpr(is_same_v || @@ -94,7 +100,7 @@ struct ReferenceGemm : public device::BaseOperator is_same_v || is_same_v) { - v_a = type_convert( + v_a = type_convert( arg.a_m_k_(m, k).unpack(k % ADataType::packed_size)); } else @@ -111,16 +117,16 @@ struct ReferenceGemm : public device::BaseOperator else i4 = (i4x2 >> 4) & 0xf; i4 = i4 - 8; - v_b = type_convert(i4); + v_b = type_convert(i4); } else if constexpr(is_same_v) { // TODO: add support for RowMajor layout as well if(k % 2 == 1) - v_b = type_convert( + v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<1>{}))); else - v_b = type_convert( + v_b = type_convert( f4_t(arg.b_k_n_(k, n).template unpack<>(Number<0>{}))); } else if constexpr(is_same_v || @@ -128,7 +134,7 @@ struct ReferenceGemm : public device::BaseOperator is_same_v 
|| is_same_v) { - v_b = type_convert( + v_b = type_convert( arg.b_k_n_(k, n).unpack(k % BDataType::packed_size)); } else @@ -136,8 +142,18 @@ struct ReferenceGemm : public device::BaseOperator arg.b_element_op_(v_b, arg.b_k_n_(k, n)); } - v_acc += - ck::type_convert(v_a) * ck::type_convert(v_b); + if constexpr(is_same_v && + is_same_v) + { // only for tf32 now + v_acc += + ck::type_convert(ck::type_convert(v_a)) * + ck::type_convert(ck::type_convert(v_b)); + } + else + { + v_acc += + ck::type_convert(v_a) * ck::type_convert(v_b); + } } CDataType v_c{0}; diff --git a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp index 28274a5154..cf30bc7dda 100644 --- a/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp @@ -38,6 +38,10 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) const CDEElementwiseOperation c_element_op) { using RowMajor = ck::tensor_layout::gemm::RowMajor; + using ElementDataTypeA = + ck::conditional_t, float, ComputeTypeA>; + using ElementDataTypeB = + ck::conditional_t, float, ComputeTypeB>; const int row_idx = blockIdx.x * blockDim.x + threadIdx.x; const int col_idx = blockIdx.y * blockDim.y + threadIdx.y; @@ -46,8 +50,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) { AccDataType v_acc{0}; - ComputeTypeA v_a{0}; - ComputeTypeB v_b{0}; + ElementDataTypeA v_a{0}; + ElementDataTypeB v_b{0}; CDataType v_c{0}; for(int k_idx = 0; k_idx < k; ++k_idx) @@ -76,7 +80,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) // apply b_element_op b_element_op(v_b, p_b_grid[element_idx_b]); // multiply and accumulate - v_acc += type_convert(v_a) * type_convert(v_b); + if constexpr(is_same_v && + is_same_v) + { // only for tf32 now + v_acc += ck::type_convert(ck::type_convert(v_a)) * + 
ck::type_convert(ck::type_convert(v_b)); + } + else + { + v_acc += type_convert(v_a) * type_convert(v_b); + } } // apply c_element_op c_element_op(v_c, v_acc); diff --git a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp index 7164f345cd..9aeca39718 100644 --- a/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp +++ b/library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp @@ -16,6 +16,7 @@ namespace instance { // aliasing, for commonly used data type using F64 = double; using F32 = float; +using TF32 = ck::tf32_t; using F16 = ck::half_t; using BF16 = ck::bhalf_t; using I8 = int8_t; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp index 82c01a634b..568f0e0dc4 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_dynamic_op_instance.hpp @@ -16,6 +16,7 @@ namespace instance { using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp index 768fcbada0..52c389d020 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp @@ -24,6 +24,7 @@ using BF8 = ck::bf8_t; using BF16 = ck::bhalf_t; using F16 = ck::half_t; using F32 = float; +using TF32 = ck::tf32_t; template using S = ck::Sequence; @@ -199,7 +200,7 @@ using device_grouped_conv_fwd_xdl_f16_nchw_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 8, 1, 8>, 1>, - // 32x32 instance + // 32x32 instance DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>, @@ -284,7 +285,45 @@ using device_grouped_conv_fwd_xdl_f32_instances = std::tuple< DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4>, DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, - DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4> + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4> + // clang-format on + >; + +template , + typename OutElementOp = PassThrough> +using device_grouped_conv_fwd_xdl_f32_tf32_instances = std::tuple< + // clang-format off + //########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| AComputeType| BComputeType| + //########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| DATATYPE | DATATYPE | + //########################################| | | | | | | | | | | | Operation| Operation| Operation| | | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl | + //########################################| | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // generic instance + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + // instances for small conv.K and conv.C + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 1, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 
4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 16>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 8, 1, 8>, 4, TF32, TF32>, + DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, 1, 1, S<1, 16, 1, 16>, 4, TF32, TF32> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp index 545826650c..5a26abecc2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp @@ -443,6 +443,12 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(op_ptrs); + } #endif #ifdef CK_ENABLE_FP8 diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp index 43411b0031..11e827878c 100644 --- 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp.hpp @@ -215,6 +215,14 @@ struct DeviceOperationInstanceFactory && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } #endif } #endif // CK_USE_XDL diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc index aaaacb0d18..045d1623cf 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_bias_clamp_xdl.inc @@ -578,6 +578,22 @@ void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_insta PassThrough, AddClamp>>>& instances); +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector && is_same_v && + is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + op_ptrs); + } #endif } #endif // CK_USE_XDL diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc index d5a8a5344a..b0061b966d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_clamp_xdl.inc @@ -578,6 +578,22 @@ void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, Clamp>>>& instances); +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances); + void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector>>& instances); + #endif #ifdef CK_ENABLE_INT8 @@ -159,7 +160,8 @@ template + typename AComputeType, + typename BComputeType = AComputeType> struct DeviceOperationInstanceFactory> + AComputeType, + BComputeType>> { using DeviceOp = DeviceGroupedConvFwdMultipleABD; + AComputeType, + BComputeType>; static auto GetInstances() { @@ -207,7 +211,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv3d_fwd_xdl_dynamic_op_ndhwgc_gkzyxc_ndhwgk_f16_instances( op_ptrs); @@ -244,7 +248,7 @@ struct DeviceOperationInstanceFactory && is_same_v && - is_same_v && is_same_v) + is_same_v && is_same_v) { add_device_grouped_conv2d_fwd_xdl_dynamic_op_nhwgc_gkyxc_nhwgk_f16_instances( op_ptrs); diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index a3f2515099..af6041bbc5 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -559,6 +559,22 @@ void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector>>& instances); + 
void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances( std::vector>>>& instances); #endif + +#ifdef CK_USE_XDL // Layout(A, B, C) = [Col, Row, Row] void add_device_gemm_quantization_xdl_c_shuffle_i8_i8_i8_km_kn_mn_instances( std::vector>>>& instances); +#endif + +#ifdef CK_USE_WMMA +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>>& + instances); + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>>& + instances); +#endif template && is_same_v && @@ -195,7 +258,9 @@ struct DeviceOperationInstanceFactory && is_same_v && @@ -206,7 +271,9 @@ struct DeviceOperationInstanceFactory && is_same_v && @@ -217,12 +284,117 @@ struct DeviceOperationInstanceFactory>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif + + return op_ptrs; + } +}; + +template +struct DeviceOperationInstanceFactory>> +{ + using DeviceOp = DeviceGemmMultipleDSplitK>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#ifdef CK_USE_WMMA + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + 
op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v) + { + add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + op_ptrs); + } + } + } +#endif + return op_ptrs; } }; @@ -230,4 +402,4 @@ struct DeviceOperationInstanceFactory>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>{}); + add_device_operation_instances( + instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt index bda9149227..6a776b4943 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/CMakeLists.txt @@ -2,7 +2,7 @@ set(GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP) include(ShardInstantiation) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -11,7 +11,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -20,7 +20,7 @@ generate_sharded_instantiations( 
SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -29,7 +29,16 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in + NUM_SHARDS 16 + SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP + OUTPUT_DIR ${GENERATED_DIR}/xdl +) + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16x16_instances @@ -38,7 +47,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instances @@ -47,7 +56,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances @@ -58,7 +67,7 @@ generate_sharded_instantiations( ) # large tensor # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -67,7 +76,7 @@ generate_sharded_instantiations( SRC_LIST 
GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -76,7 +85,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/large_tensor ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -87,7 +96,7 @@ generate_sharded_instantiations( ) # merged groups # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_bf16_instances @@ -96,7 +105,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f16_instances @@ -105,7 +114,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/merged_groups ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_merged_groups_ndhwgc_gkzyxc_ndhwgk_f32_instances @@ -116,7 +125,7 @@ generate_sharded_instantiations( ) #mem # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instances @@ -125,7 +134,7 @@ generate_sharded_instantiations( SRC_LIST 
GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instances @@ -134,7 +143,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instances @@ -144,7 +153,7 @@ generate_sharded_instantiations( OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instances @@ -153,7 +162,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instances @@ -162,7 +171,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/mem ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances @@ -173,7 +182,7 @@ generate_sharded_instantiations( ) #comp # NDHWGC, GKZYXC, NDHWGK - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances @@ -182,7 +191,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR 
${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances @@ -191,7 +200,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instances @@ -200,7 +209,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_2x_instances @@ -209,7 +218,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_2x_instances @@ -218,7 +227,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_part2_instances @@ -227,7 +236,7 @@ generate_sharded_instantiations( SRC_LIST GROUPED_conv3d_FWD_BIAS_BNORM_CLAMP OUTPUT_DIR ${GENERATED_DIR}/xdl/comp ) - + set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) generate_sharded_instantiations( INSTANCES_NAME device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_part2_instances diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in new file mode 100644 index 0000000000..d7f3c87b83 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_bnorm_clamp/xdl/device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instance.in @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances = + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + BiasNormalizeInInferClamp, + TF32, + TF32>>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances_shard( + device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwdDefault, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + 
ck::util::filter_tuple_by_modulo_t, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + BiasNormalizeInInferClamp>, + Shards, + ShardIndex>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt index 3bd6916cf0..bcc7020ca9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/CMakeLists.txt @@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp -) + + xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp + ) add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..328838bff2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_bias_clamp/xdl/device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple, + F32, + PassThrough, + PassThrough, + AddClamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwdDefault, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1P0, + Tuple, + AddClamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple, + NDHWGK, + ConvFwd1x1S1P0, + Tuple, + AddClamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt index 234533244e..059d22f8d2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/CMakeLists.txt @@ -23,6 +23,8 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_mem_intra_instance.cpp xdl/comp/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_comp_instance.cpp -) + + xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp + ) 
add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD}) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp new file mode 100644 index 0000000000..a1bf6562c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd_clamp/xdl/device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_fp32_tf32_instance.cpp @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances( + std::vector, + NDHWGK, + F32, + F32, + Tuple<>, + F32, + PassThrough, + PassThrough, + Clamp, + TF32, + TF32>>>& instances) +{ + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwdDefault, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1P0, + Tuple<>, + Clamp>{}); + add_device_operation_instances(instances, + device_grouped_conv_fwd_xdl_f32_tf32_instances<3, + NDHWGC, + GKZYXC, + Tuple<>, + NDHWGK, + ConvFwd1x1S1P0, + Tuple<>, + Clamp>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt 
b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt index 5d50902be8..a5b4fb5df4 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/quantization/CMakeLists.txt @@ -20,6 +20,12 @@ list(APPEND GEMM_QUANT_SRC gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp gemm/device_gemm_quantization_dl_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp) +list(APPEND GEMM_QUANT_SRC + gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp + gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp + gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp + gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp) + add_instance_library(device_quantization_instance ${CONV2D_PERLAYER_QUANT_SRC} ${CONV2D_PERCHANNEL_QUANT_SRC} diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp new file mode 100644 index 0000000000..3737f0a958 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gemm_quantization_common.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, 
S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 128, 32, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| 
BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 160, 64, 8, 8, 16, 16, 2, 5, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, S<8, 32, 1>, 
S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 64, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Col, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 64, 80, 64, 8, 8, 16, 16, 1, 5, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 64, 1, 2>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | 
| | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Row, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +template +using device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances = std::tuple< + // clang-format off + //################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GemmSpec| Block| MPer| NPer| KPer| 
AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CDEShuffleBlockTransfer| BlkGemm| BlkGemm| ComputeTypeA| ComputeTypeB| + //################################| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| | Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVectors| PipeSched| PipelineVer| | | + //################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _MBlock_MPerBlock| | | | | | + //################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | _NBlock_NPerBlock| | | | | | + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 128, 32, 8, 8, 16, 16, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 128, 128, 64, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 32, 1, 4>, S<1>, 
GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 256, 128, 256, 64, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 1, 1, S<1, 32, 1, 8>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t>, + DeviceGemmMultipleD_Wmma_CShuffleV3< Row, Col, Empty_Tuple, Row, int8_t, int8_t, Empty_Tuple, int8_t, int32_t, int32_t, PassThrough, PassThrough, OutElementOp, MNKPadding, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, 1, 1, S<1, 16, 1, 4>, S<1>, GemmPipelineScheduler, GemmPipeline, int8_t, int8_t> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..a3838bb398 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..31ff723166 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_km_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..07a632a77c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_kn_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..ed9cc908ef --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v3>{}); + add_device_operation_instances( + instances, + device_gemm_quantization_wmma_c_shuffle_i8_i8_i8_mk_nk_mn_instances< + Mul_Clamp, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp index e7c2500fef..a4eb29c7a1 100644 --- a/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp +++ b/library/src/tensor_operation_instance/gpu/quantization/gemm/gemm_quantization_common.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -33,7 +33,8 @@ using Relu_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp< using Add_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; using Add_Relu_Mul_Clamp = ck::tensor_operation::element_wise::Add_Activation_Mul_Clamp; -static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; +static constexpr auto MNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; } // namespace instance } // namespace device diff --git a/library/src/utility/host_tensor.cpp b/library/src/utility/host_tensor.cpp index 7211552641..02bd562e43 100644 --- a/library/src/utility/host_tensor.cpp +++ b/library/src/utility/host_tensor.cpp @@ -53,7 +53,7 @@ std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc) os << "strides {"; LogRange(os, desc.GetStrides(), ", "); - os << "}"; + os << "} "; return os; } diff --git a/profiler/include/profiler/profile_gemm_quantization_impl.hpp b/profiler/include/profiler/profile_gemm_quantization_impl.hpp new file mode 100644 index 0000000000..a115a41a34 --- /dev/null +++ b/profiler/include/profiler/profile_gemm_quantization_impl.hpp @@ -0,0 +1,231 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/gpu/quantization/gemm_quantization.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +namespace ck { +namespace profiler { + +template +bool profile_gemm_quantization_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideE, + float requant_scale = 0.03f) +{ + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); + Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_device_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + default: + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + } + + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + using MulClamp = 
ck::tensor_operation::element_wise::Activation_Mul_Clamp; + + using AElementOp = PassThrough; + using BElementOp = PassThrough; + using ActivationOp = PassThrough; + using CDEElementOp = MulClamp; + + const auto a_element_op = AElementOp{}; + const auto b_element_op = BElementOp{}; + const auto cde_element_op = CDEElementOp{requant_scale, ActivationOp{}}; + + using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleD< + ALayout, + BLayout, + ck::Tuple<>, + ELayout, + ADataType, + BDataType, + ck::Tuple<>, + EDataType, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::Activation_Mul_Clamp>; + + // get device op instances + const auto op_ptrs = ck::tensor_operation::device::instance::DeviceOperationInstanceFactory< + DeviceOp>::GetInstances(); + + std::cout << "found " << op_ptrs.size() << " instances" << std::endl; + + // run reference + if(do_verification) + { + Tensor c_m_n({M, N}); + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, a_element_op, b_element_op, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n)); + } + } + } + + DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.mDesc.GetElementSpaceSize()); + DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a_device_buf.ToDevice(a_m_k.mData.data()); + b_device_buf.ToDevice(b_k_n.mData.data()); + + std::string best_op_name; + float best_ave_time = 0; + float best_tflops = 0; + float best_gb_per_sec = 0; + + bool pass = true; + + // profile device operation instances + for(auto& op_ptr : 
op_ptrs) + { + auto argument_ptr = op_ptr->MakeArgumentPointer(a_device_buf.GetDeviceBuffer(), + b_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a_element_op, + b_element_op, + cde_element_op); + + auto invoker_ptr = op_ptr->MakeInvokerPointer(); + + std::string op_name = op_ptr->GetTypeString(); + + if(op_ptr->IsSupportedArgument(argument_ptr.get())) + { + // re-init E to zero before profiling a kernel + e_device_buf.SetZero(); + + float ave_time = + invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, time_kernel}); + + std::size_t flop = std::size_t(2) * M * N * K; + + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(EDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " + << gb_per_sec << " GB/s, " << op_name << std::endl; + + if(tflops > best_tflops) + { + best_op_name = op_name; + best_tflops = tflops; + best_ave_time = ave_time; + best_gb_per_sec = gb_per_sec; + } + + if(do_verification) + { + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = pass && ck::utils::check_err(e_m_n_device_result, e_m_n_host_result); + + if(do_log) + { + LogRangeAsType( + std::cout << "e_m_n_device_result: ", e_m_n_device_result.mData, ",") + << std::endl; + + LogRangeAsType( + std::cout << "e_m_n_host_result: ", e_m_n_host_result.mData, ",") + << std::endl; + } + } + } + else + { + std::cout << op_name << " does not support this problem" << std::endl; + } + } + + std::cout << "Best Perf: " << best_ave_time << " ms, " << best_tflops << " TFlops, " + << best_gb_per_sec << " GB/s, " << best_op_name << std::endl; + + return pass; +} + +} // namespace profiler +} // namespace ck diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 
7cfdc5bfc9..31f684fe75 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -32,6 +32,7 @@ set(PROFILER_OPS profile_conv_tensor_rearrange.cpp profile_transpose.cpp profile_permute_scale.cpp + profile_gemm_quantization.cpp ) if(SUPPORTED_GPU_TARGETS MATCHES "gfx9") @@ -112,6 +113,10 @@ if(DL_KERNELS) list(APPEND PROFILER_OPS profile_grouped_conv_bwd_weight.cpp) endif() +if(CK_ENABLE_INT8) + list(APPEND PROFILER_OPS profile_gemm_quantization.cpp) +endif() + set(PROFILER_SOURCES profiler.cpp) foreach(SOURCE ${PROFILER_OPS}) string(REGEX REPLACE "profile_(.+)\.cpp" "\\1" OP_NAME ${SOURCE}) @@ -248,6 +253,10 @@ if(DL_KERNELS) list(APPEND DEVICE_INSTANCES device_grouped_conv3d_bwd_weight_instance) endif() +if(CK_ENABLE_INT8) + list(APPEND DEVICE_INSTANCES device_quantization_instance) +endif() + set(PROFILER_LIBS utility getopt::getopt) foreach(LIB ${DEVICE_INSTANCES}) string(REGEX REPLACE "device_(.+)_instance" "\\1" INSTANCE_NAME ${LIB}) diff --git a/profiler/src/profile_gemm_quantization.cpp b/profiler/src/profile_gemm_quantization.cpp new file mode 100644 index 0000000000..d28dd60dce --- /dev/null +++ b/profiler/src/profile_gemm_quantization.cpp @@ -0,0 +1,115 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include +#include +#include +#include + +#include "profiler/profile_gemm_quantization_impl.hpp" +#include "profiler_operation_registry.hpp" + +#define OP_NAME "gemm_quantization" +#define OP_DESC "GEMM Quantization" + +using INT8 = int8_t; +using INT32 = int32_t; + +int profile_gemm_quantization(int argc, char* argv[]) +{ + enum struct MatrixLayout + { + MK_KN_MN, // 0: + MK_NK_MN, // 1: + KM_KN_MN, // 2: + KM_NK_MN, // 3: + }; + + if(argc != 14) + { + // clang-format off + printf("arg1: tensor operation (" OP_NAME ": " OP_DESC ")\n"); + printf("arg2: matrix layout (0: E[m, n] = A[m, k] * B[k, n];\n"); + printf(" 1: E[m, n] = A[m, k] * B[n, k];\n"); + printf(" 2: E[m, n] = A[k, m] * B[k, n];\n"); + printf(" 3: E[m, n] = A[k, m] * B[n, k])\n"); + printf("arg3: verification (0: no; 1: yes)\n"); + printf("arg4: initialization (0: no init; default: integer value)\n"); + printf("arg5: print tensor value (0: no; 1: yes)\n"); + printf("arg6: time kernel (0=no, 1=yes)\n"); + printf("arg7 to 12: M, N, K, StrideA, StrideB, StrideE\n"); + printf("arg13: requant_scale (float, e.g., 0.03)\n"); + // clang-format on + exit(1); + } + + const auto layout = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const bool time_kernel = std::stoi(argv[6]); + + const int M = std::stoi(argv[7]); + const int N = std::stoi(argv[8]); + const int K = std::stoi(argv[9]); + + const int StrideA = std::stoi(argv[10]); + const int StrideB = std::stoi(argv[11]); + const int StrideE = std::stoi(argv[12]); + + const float requant_scale = std::stof(argv[13]); + + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + + auto profile = [&](auto a_layout, auto b_layout, auto e_layout) { + using ALayout = decltype(a_layout); + using BLayout = decltype(b_layout); + using ELayout = decltype(e_layout); + + bool pass = 
ck::profiler::profile_gemm_quantization_impl(do_verification, + init_method, + do_log, + time_kernel, + M, + N, + K, + StrideA, + StrideB, + StrideE, + requant_scale); + + return pass ? 0 : 1; + }; + + if(layout == MatrixLayout::MK_KN_MN) + { + return profile(Row{}, Row{}, Row{}); + } + else if(layout == MatrixLayout::MK_NK_MN) + { + return profile(Row{}, Col{}, Row{}); + } + else if(layout == MatrixLayout::KM_KN_MN) + { + return profile(Col{}, Row{}, Row{}); + } + else if(layout == MatrixLayout::KM_NK_MN) + { + return profile(Col{}, Col{}, Row{}); + } + else + { + std::cout << "this layout is not implemented" << std::endl; + return 1; + } +} + +REGISTER_PROFILER_OPERATION(OP_NAME, OP_DESC, profile_gemm_quantization); diff --git a/profiler/src/profile_grouped_conv_fwd.cpp b/profiler/src/profile_grouped_conv_fwd.cpp index a7714b4c73..a8d343405d 100644 --- a/profiler/src/profile_grouped_conv_fwd.cpp +++ b/profiler/src/profile_grouped_conv_fwd.cpp @@ -21,14 +21,15 @@ enum struct ConvLayout enum struct ConvDataType { - F32_F32_F32, // 0 - F16_F16_F16, // 1 - BF16_BF16_BF16, // 2 - INT8_INT8_INT8, // 3 - F8_F8_F8, // 4 - BF8_BF8_F8, // 5 - F8_BF8_F8, // 6 - BF8_F8_F8, // 7 + F32_F32_F32, // 0 + F16_F16_F16, // 1 + BF16_BF16_BF16, // 2 + INT8_INT8_INT8, // 3 + F8_F8_F8, // 4 + BF8_BF8_F8, // 5 + F8_BF8_F8, // 6 + BF8_F8_F8, // 7 + F32_F32_F32_TF32, // 8 }; enum struct IndexType @@ -52,7 +53,8 @@ static void print_helper_msg() << " 4: Input fp8, Weight fp8, Output fp8\n" << " 5: Input bf8, Weight bf8, Output fp8\n" << " 6: Input fp8, Weight bf8, Output fp8\n" - << " 7: Input bf8, Weight fp8, Output fp8)\n" + << " 7: Input bf8, Weight fp8, Output fp8\n" + << " 8: Input fp32, Weight fp32, Output fp32, Compute tf32)\n" << "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n" << " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n" << " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, " @@ -103,6 
+105,9 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) using INT8 = int8_t; using F8 = ck::f8_t; using BF8 = ck::bf8_t; +#if defined(__gfx942__) + using TF32 = ck::tf32_t; +#endif // using GNWC = ck::tensor_layout::convolution::GNWC; @@ -261,6 +266,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return profile( I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, INT8{}, INT8{}, INT8{}, INT8{}, INT8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } // NHWGC_GKYXC_NHWGK else if(num_dim_spatial == 1 && layout == ConvLayout::NHWGC_GKYXC_NHWGK) @@ -367,6 +378,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) { return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, BF8{}, F8{}, F8{}, BF8{}, F8{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } // NGCDHW_GKCZYX_NGKDHW else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKCYX_NGKHW) @@ -384,6 +401,12 @@ int profile_grouped_conv_fwd(int argc, char* argv[]) return profile( I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{}); } + else if(data_type == ConvDataType::F32_F32_F32_TF32) + { +#if defined(__gfx942__) + return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{}); +#endif + } } std::cout << "this data_type & layout is not implemented" << std::endl; diff --git a/script/cmake-ck-dev.sh b/script/cmake-ck-dev.sh index 086359a79f..6220009b03 100755 --- a/script/cmake-ck-dev.sh +++ b/script/cmake-ck-dev.sh @@ -20,7 +20,7 @@ fi GPU_TARGETS="gfx908;gfx90a;gfx942" if [ $# -ge 1 ]; then - case "$1" in + case "$1" in gfx*) GPU_TARGETS=$1 shift 1 @@ -38,7 +38,7 @@ fi cmake \ -D CMAKE_PREFIX_PATH=/opt/rocm/ \ -D CMAKE_CXX_COMPILER=/opt/rocm/llvm/bin/clang++ \ --D 
CMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker" \ +-D CMAKE_CXX_FLAGS="-ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker -fbracket-depth=512" \ -D CMAKE_BUILD_TYPE=Release \ -D BUILD_DEV=ON \ -D GPU_TARGETS=$GPU_TARGETS \ diff --git a/script/launch_tests.sh b/script/launch_tests.sh index 5e71e25478..17a99e62a3 100755 --- a/script/launch_tests.sh +++ b/script/launch_tests.sh @@ -49,7 +49,7 @@ with open('$TEST_FILE', 'r') as f: if tests: # Extract just the filename after the last '/' clean_tests = [os.path.basename(test) for test in tests] - print('ctest -R \"' + '|'.join(clean_tests) + '\"') + print('ctest --output-on-failure -R \"' + '|'.join(clean_tests) + '\"') else: print('# No tests to run') ") @@ -57,5 +57,3 @@ with open('$TEST_FILE', 'r') as f: echo "$command" eval "$command" - - diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index f898f67685..cedac568db 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -277,6 +277,7 @@ add_subdirectory(conv_tensor_rearrange) add_subdirectory(transpose) add_subdirectory(permute_scale) add_subdirectory(wrapper) +add_subdirectory(quantization) if(SUPPORTED_GPU_TARGETS MATCHES "gfx11") add_subdirectory(wmma_op) endif() diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 993df2ec40..b08f0d8316 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -3,7 +3,10 @@ add_subdirectory(gemm) add_subdirectory(gemm_weight_preshuffle) add_subdirectory(batched_gemm) add_subdirectory(grouped_gemm) +add_subdirectory(grouped_gemm_preshuffle) add_subdirectory(gemm_multi_d) +add_subdirectory(gemm_multi_abd) +add_subdirectory(gemm_streamk) add_subdirectory(data_type) add_subdirectory(container) add_subdirectory(elementwise) diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp index dd90034064..d997596414 
100644 --- a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -58,7 +58,7 @@ float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc index b2f965764d..8f24c9bfe1 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc +++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc @@ -23,13 +23,17 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic) std::vector gemmParams{{256, 256, 256, 1}, {256, 256, 256, 2}, {256, 256, 512, 2}, - {256, 256, 128, 2}, {256, 256, 64, 2}, {256, 256, 64, 3}, {256, 256, 64, 4}, {256, 256, 64, 8}, {256, 256, 64, 16}}; + if(ck_tile::get_device_name() != "gfx950") + { + gemmParams.emplace_back(256, 256, 128, 2); + } + for(auto& params : gemmParams) { std::vector strideConfigs{{params.K, diff --git a/test/ck_tile/data_type/test_pk_fp4.cpp b/test/ck_tile/data_type/test_pk_fp4.cpp index 15f027e95d..b1e981624a 100644 --- a/test/ck_tile/data_type/test_pk_fp4.cpp +++ b/test/ck_tile/data_type/test_pk_fp4.cpp @@ -2,6 +2,7 @@ // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#include "gtest/gtest.h" +#include #include #include "ck_tile/core.hpp" @@ -29,6 +30,12 @@ TEST(PackedFp4, NumericLimits) EXPECT_EQ(ck_tile::numeric::epsilon(), pk_fp4_t{0b00010001}); EXPECT_EQ(ck_tile::numeric::round_error(), pk_fp4_t{0b00010001}); } +TEST(PackedFp4, fill) +{ + std::vector v_fp4(4); + ck_tile::FillUniformDistribution{1.f, 1.f}(v_fp4); + EXPECT_EQ(v_fp4[0].get(), pk_fp4_t{0b00100010}.get()); +} TEST(PackedFp4, ConvertBasic) { EXPECT_EQ(ck_tile::convert_to_type(0.0f), pk_fp4_t{0b00000000}.get()); @@ -102,7 +109,7 @@ struct SrcPkfp4Dst // ex: fp32_t -> fp4 -> bf16_t dst[i] = toDST(toPF4(src[i])); // ex: fp32x2_t -> pk_fp4 -> unpack<0> -> bf16_t - dst[i + 1] = toDST(toPF4(toPF4(input2).unpack(number<1>{}))); + dst[i + 1] = toDST(toPF4(input2).unpack(number<1>{})); } else { diff --git a/test/ck_tile/fmha/test_fmha_fwd.inc b/test/ck_tile/fmha/test_fmha_fwd.inc index f02ef1e55e..08abd3358d 100644 --- a/test/ck_tile/fmha/test_fmha_fwd.inc +++ b/test/ck_tile/fmha/test_fmha_fwd.inc @@ -32,9 +32,6 @@ const ck_tile::stream_config stream_config{ 1, // rotating_count_ }; -// range_q, range_k, range_v, range_p, range_o, squant -#define QUANT_ARGS 1, 1, 1, 1, 1, squant - #define COMMON_ARGS \ init_method, static_cast(ck_tile::EnvValue(CK_TILE_ENV(CK_TILE_TEST_SEED))), 1, \ stream_config @@ -117,7 +114,7 @@ TEST_P(AllLong, Test) 1024, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -179,7 +176,7 @@ TEST_P(HDimPadding, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -236,7 +233,7 @@ TEST_P(ElementwiseBias, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -292,7 +289,7 @@ TEST_P(Alibi, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - 
QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -350,7 +347,7 @@ TEST_P(Dropout, Test) drop_offset, // drop_offset drop_prefs, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -410,7 +407,7 @@ TEST_P(PagedKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -476,7 +473,7 @@ TEST_P(SplitKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, true, // is_rotary_interleaved num_splits, // num_splits COMMON_ARGS); @@ -548,7 +545,7 @@ TEST_P(AppendKV, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, false, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); @@ -618,7 +615,7 @@ TEST_P(AppendKVRoPE, Test) 0, // drop_offset false, // drop_prefs mask_str, // mask_str - QUANT_ARGS, + squant, is_rotary_interleaved, // is_rotary_interleaved 1, // num_splits COMMON_ARGS); diff --git a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp index 46ed8f4125..b99c304d1f 100644 --- a/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp +++ b/test/ck_tile/fmha/test_fmha_fwd_fp8.cpp @@ -17,22 +17,21 @@ using DataTypeConfig = FmhaFwdFp8; // instances are added), however the corresponding tests are not disabled (they will be skipped) // in case such instances will be added in the future. 
-const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto HDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); -const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto SplitKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); -const auto AppendKVHDimValues = - Values(std::tuple{64, -1}, std::tuple{128, -1}, std::tuple{256, -1}); +const auto AppendKVHDimValues = Values(std::tuple{64, -1}, std::tuple{128, -1}); // There are no fp8 instances with seqlen padding (mode_enum::group requires it) const auto ModeValues = Values(mode_enum::batch); const auto IsVRowmajorValues = Values(false); -const bool squant = true; -const std::string init_method = "ufq"; +const auto squant = true; +const std::string init_method = "uf"; const bool def_lse = false; -const bool def_is_v_rowmajor = false; +const bool def_is_v_rowmajor = true; int adjust_seqlen(int seqlen) { diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc index ab74e4e7b1..57feefceab 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -2,6 +2,8 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once +#include "ck_tile/host/permute_pk_int4.hpp" + template static constexpr inline auto is_row_major(Layout layout_) { @@ -91,61 +93,6 @@ void permute_tensor_b(Tensor& tensor) } } -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 6, i) = i4x2; - } - } - } -} - template + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_abd_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // Has cshuffle epilogue enabled + // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, 
Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::true_type>, + + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, F8, F8, F32, F32, AddScale, AddScale, MultiplyMultiply, std::true_type> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); + +#include "test_gemm_multi_abd_ut_cases_cshuffle.inc" diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp new file mode 100644 index 0000000000..b3a89aba05 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_default2d.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_gemm_multi_abd_util.hpp" + +using F16 = ck_tile::half_t; +using BF16 = ck_tile::bf16_t; +using F32 = float; +using F8 = ck_tile::fp8_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypes = ::testing::Types< + // Has cshuffle epilogue disabled + // A0Layout, A1Layout, B0Layout, B1Layout CLayout, D0Layout, D1Layout, A0DataType, A01DataType B0DataType, B0DataType, D0DataType, D1DataType, AccDataType, EDataType, AElementWiseFn, BElementWiseFn, CDElementWiseFn, UseCshuffleEpilog + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, ElementWiseAddAdd, std::false_type>, + + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F16, F16, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, 
Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, F32, F32, F32, F32, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F16, F16, F16, F16, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type>, + std::tuple< Row, Row, Col, Col, Row, Row, Row, F8, F8, F8, F8, BF16, BF16, F32, BF16, AddScale, AddScale, MultiplyMultiply, std::false_type> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGemmMultiABD, KernelTypes); + +#include "test_gemm_multi_abd_ut_cases_default2d.inc" diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc new file mode 100644 index 0000000000..5aa113608f --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_cshuffle.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x768x256) +{ + constexpr int M = 512; + 
constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x256x256) +{ + 
constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), 
std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2CShuffle_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff --git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc new file mode 100644 index 0000000000..cc7603164c --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_ut_cases_default2d.inc @@ -0,0 +1,211 @@ +#pragma once + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x512x256) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + 
constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch1Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 256; + constexpr int kBatch = 1; + + EXPECT_EQ(this->Run(M, N, K, kBatch), true); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x512) +{ + constexpr int M = 512; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x512x256) +{ + constexpr int M = 256; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x256x256) +{ + constexpr int M = 512; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x512x256) +{ + constexpr int M = 512; + 
constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x256x256) +{ + constexpr int M = 256; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x768x256) +{ + constexpr int M = 512; + constexpr int N = 768; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_512x1280x256) +{ + constexpr int M = 512; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_256x1280x256) +{ + constexpr int M = 256; + constexpr int N = 1280; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_768x512x256) +{ + constexpr int M = 768; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x512x256) +{ + constexpr int M = 1280; + constexpr int N = 512; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} + +TYPED_TEST(TestCkTileGemmMultiABD, TestCkTileGemmMultiABDKBatch2Default_1280x256x256) +{ + constexpr int M = 1280; + constexpr int N = 256; + constexpr int K = 512; + constexpr int kBatch = 2; + + EXPECT_THROW(this->Run(M, N, K, kBatch), std::runtime_error); +} diff 
--git a/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp new file mode 100644 index 0000000000..428bed4e25 --- /dev/null +++ b/test/ck_tile/gemm_multi_abd/test_gemm_multi_abd_util.hpp @@ -0,0 +1,500 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/gemm_multi_abd_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +struct AddScale +{ + template + CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const A0& a0, const A1& a1) const + { + a = scale * (ck_tile::type_convert(a0) + ck_tile::type_convert(a1)); + } + + float scale = 1.0; +}; + +struct MultiplyMultiply +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) * ck_tile::type_convert(d0) * + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +struct ElementWiseAddAdd +{ + template + CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const D0& d0, const D1& d1) const -> void + { + const float x0_f = ck_tile::type_convert(c) + ck_tile::type_convert(d0) + + ck_tile::type_convert(d1); + + e = ck_tile::type_convert(x0_f); + } +}; + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeTypeAB = + std::conditional_t; + + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + 
ck_tile::integer_divide_ceil(K, kbatch)); + + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileGemmMultiABD : public ::testing::Test +{ + protected: + using A0Layout = std::tuple_element_t<0, Tuple>; + using A1Layout = std::tuple_element_t<1, Tuple>; + using B0Layout = std::tuple_element_t<2, Tuple>; + using B1Layout = std::tuple_element_t<3, Tuple>; + using D0Layout = std::tuple_element_t<4, Tuple>; + using D1Layout = std::tuple_element_t<5, Tuple>; + using ELayout = std::tuple_element_t<6, Tuple>; + using A0DataType = std::tuple_element_t<7, Tuple>; + using A1DataType = std::tuple_element_t<8, Tuple>; + using B0DataType = std::tuple_element_t<9, Tuple>; + using B1DataType = std::tuple_element_t<10, Tuple>; + using D0DataType = std::tuple_element_t<11, Tuple>; + using D1DataType = std::tuple_element_t<12, Tuple>; + using AccDataType = std::tuple_element_t<13, Tuple>; + using EDataType = std::tuple_element_t<14, Tuple>; + using AElementWiseFn = std::tuple_element_t<15, Tuple>; + using BElementWiseFn = std::tuple_element_t<16, Tuple>; + using CDElementWiseFn = std::tuple_element_t<17, Tuple>; + using UseCshuffleEpilog = std::tuple_element_t<18, Tuple>; + + using AsLayout = ck_tile::tuple; + using AsDataType = ck_tile::tuple; + using BsLayout = ck_tile::tuple; + using BsDataType = ck_tile::tuple; + using DsLayout = ck_tile::tuple; + using DsDataType = ck_tile::tuple; + + template + void invoke_gemm_multi_abd(const ck_tile::GemmMultiABDHostArgs& args, + const ck_tile::stream_config& s) + { + constexpr ck_tile::index_t M_Tile = 
256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool DoubleSmemBuffer = false; + + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr bool TransposeC = false; + + constexpr int kBlockPerCu = 1; + constexpr ck_tile::index_t TileParitionerGroupNum = 8; + constexpr ck_tile::index_t TileParitionerM01 = 4; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t k_grain = args.k_batch * K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + + using DefaultGemmEpilogue = 
ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; + + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + + using GemmEpilogue = std:: + conditional_t; + + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + std::cout << "Run without SplitK" << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + std::cout << "Run using SplitK" << std::endl; + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } + + public: + bool Run(const int M, + const int N, + const int K, + const int k_batch, + int StrideA0 = 0, + int StrideA1 = 0, + int StrideB0 = 0, + int StrideB1 = 0, + int StrideD0 = 0, + int StrideD1 = 0, + int StrideE = 0) + { + using namespace ck_tile::literals; + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return 
ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + StrideA0 = f_get_default_stride(M, K, StrideA0, A0Layout{}); + StrideA1 = f_get_default_stride(M, K, StrideA1, A1Layout{}); + + StrideB0 = f_get_default_stride(K, N, StrideB0, B0Layout{}); + StrideB1 = f_get_default_stride(K, N, StrideB1, B1Layout{}); + + StrideD0 = f_get_default_stride(M, N, StrideD0, D0Layout{}); + StrideD1 = f_get_default_stride(M, N, StrideD1, D1Layout{}); + + StrideE = f_get_default_stride(M, N, StrideE, ELayout{}); + + ck_tile::HostTensor a0_m_k_tesnor( + f_host_tensor_descriptor(M, K, StrideA0, A0Layout{})); + ck_tile::HostTensor a1_m_k_tesnor( + f_host_tensor_descriptor(M, K, StrideA1, A1Layout{})); + + ck_tile::HostTensor b0_k_n_tensors( + f_host_tensor_descriptor(K, N, StrideB0, B0Layout{})); + ck_tile::HostTensor b1_k_n_tensors( + f_host_tensor_descriptor(K, N, StrideB1, B1Layout{})); + + ck_tile::HostTensor d0_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD0, D0Layout{})); + ck_tile::HostTensor d1_m_n_tensors( + f_host_tensor_descriptor(M, N, StrideD1, D1Layout{})); + + ck_tile::HostTensor e_m_n_device_result( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a0_m_k_tesnor); + ck_tile::FillUniformDistribution{-1.f, 1.f}(a1_m_k_tesnor); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(b0_k_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b1_k_n_tensors); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(d0_m_n_tensors); + ck_tile::FillUniformDistribution{-1.f, 1.f}(d1_m_n_tensors); + + ck_tile::DeviceMem a0_m_k_dev_buf(a0_m_k_tesnor.get_element_space_size_in_bytes()); + 
ck_tile::DeviceMem a1_m_k_dev_buf(a1_m_k_tesnor.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem b0_k_n_dev_buf(b0_k_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b1_k_n_dev_buf(b1_k_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem d0_m_n_dev_buf(d0_m_n_tensors.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d1_m_n_dev_buf(d1_m_n_tensors.get_element_space_size_in_bytes()); + + ck_tile::DeviceMem e_m_n_dev_buf(e_m_n_device_result.get_element_space_size_in_bytes()); + + a0_m_k_dev_buf.ToDevice(a0_m_k_tesnor.mData.data()); + a1_m_k_dev_buf.ToDevice(a1_m_k_tesnor.mData.data()); + + b0_k_n_dev_buf.ToDevice(b0_k_n_tensors.mData.data()); + b1_k_n_dev_buf.ToDevice(b1_k_n_tensors.mData.data()); + + d0_m_n_dev_buf.ToDevice(d0_m_n_tensors.mData.data()); + d1_m_n_dev_buf.ToDevice(d1_m_n_tensors.mData.data()); + + e_m_n_dev_buf.SetZero(); + e_m_n_device_result.SetZero(); + + std::array as_ptr_buf = {a0_m_k_dev_buf.GetDeviceBuffer(), + a1_m_k_dev_buf.GetDeviceBuffer()}; + + std::array bs_ptr_buf = {b0_k_n_dev_buf.GetDeviceBuffer(), + b1_k_n_dev_buf.GetDeviceBuffer()}; + + std::array ds_ptr_buf = {d0_m_n_dev_buf.GetDeviceBuffer(), + d1_m_n_dev_buf.GetDeviceBuffer()}; + + std::array strideAs = {StrideA0, StrideA1}; + std::array strideBs = {StrideB0, StrideB1}; + std::array strideDs = {StrideD0, StrideD1}; + + ck_tile::GemmMultiABDHostArgs + args({as_ptr_buf, + bs_ptr_buf, + ds_ptr_buf, + e_m_n_dev_buf.GetDeviceBuffer(), + k_batch, + M, + N, + K, + strideAs, + strideBs, + strideDs, + StrideE}); + + invoke_gemm_multi_abd(args, ck_tile::stream_config{nullptr, false}); + + std::cout << "Run kernel with M =" << M << " N =" << N << " K =" << K + << " StrideA0 =" << StrideA0 << " StrideA1 =" << StrideA1 + << " StrideB0 =" << StrideB0 << " StrideB1 =" << StrideB1 + << " StrideE =" << StrideE << " StrideD0 =" << StrideD0 + << " StrideD1 =" << StrideD1 << std::endl; + + e_m_n_dev_buf.FromDevice(e_m_n_device_result.data()); + bool 
pass = true; + + ck_tile::HostTensor a_m_k_host_ref_element_result( + f_host_tensor_descriptor(M, K, StrideA0, A0Layout{})); + ck_tile::HostTensor b_k_n_host_ref_element_result( + f_host_tensor_descriptor(K, N, StrideB0, B0Layout{})); + ck_tile::HostTensor e_m_n_host_ref( + f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + a_m_k_host_ref_element_result.SetZero(); + b_k_n_host_ref_element_result.SetZero(); + e_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm_multiple_abd({a0_m_k_tesnor, a1_m_k_tesnor}, + {b0_k_n_tensors, b1_k_n_tensors}, + {d0_m_n_tensors, d1_m_n_tensors}, + a_m_k_host_ref_element_result, + b_k_n_host_ref_element_result, + e_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(e_m_n_host_ref.mData.begin(), e_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + K, k_batch, max_accumulated_value); + pass = ck_tile::check_err(e_m_n_device_result, + e_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + + return pass; + } +}; diff --git a/test/ck_tile/gemm_streamk/CMakeLists.txt b/test/ck_tile/gemm_streamk/CMakeLists.txt new file mode 100644 index 0000000000..e00874ba07 --- /dev/null +++ b/test/ck_tile/gemm_streamk/CMakeLists.txt @@ -0,0 +1,7 @@ +# Currently test_ck_tile_streamk is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + #TODO: support all arches + add_gtest_executable(test_ck_tile_streamk test_gemm_streamk.cpp) +else() + message(DEBUG "Skipping test_ck_tile_streamk tests for current target") +endif() diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp b/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp new file mode 100644 index 0000000000..99c3fb397f --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk.cpp @@ -0,0 +1,14 @@ +// 
Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "test_gemm_streamk_types.hpp" +#include "test_gemm_streamk_util.hpp" +#include "gtest/gtest.h" + +#define TEST_SUITE_NAME TestCkTileStreamK + +TYPED_TEST_SUITE(TestCkTileStreamK, KernelTypesStreamK); + +#include "test_gemm_streamk_cases.inc" + +#undef TEST_SUITE_NAME diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc b/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc new file mode 100644 index 0000000000..1db7ef0fb0 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_cases.inc @@ -0,0 +1,118 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_DP) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks4) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 4; + + this->Run(M, N, K, num_sk_blocks); +} + +// TODO: Renable this test once reduction is implemented +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks12) +{ + GTEST_SKIP() << "Skipping this test: There are precision issues with atomics due to >=3 WGs " + "contributing to each macro tile in C"; + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 12; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M256_N256_K256_SKBlocks8) +{ + + ck_tile::index_t M = 256; + ck_tile::index_t N = 256; + ck_tile::index_t K = 256; + uint32_t num_sk_blocks = 8; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_DP) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + 
uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks16) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 16; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M512_N512_K512_SKBlocks8) +{ + + ck_tile::index_t M = 512; + ck_tile::index_t N = 512; + ck_tile::index_t K = 512; + uint32_t num_sk_blocks = 8; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_DP) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 0; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_M3840_N4096_K4096_SKBlocks64) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 64; + + this->Run(M, N, K, num_sk_blocks); +} + +TYPED_TEST(TEST_SUITE_NAME, StreamK_Unsupported_Reduction) +{ + + ck_tile::index_t M = 3840; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + uint32_t num_sk_blocks = 64; + + EXPECT_THROW(this->Run(M, N, K, num_sk_blocks, ck_tile::StreamKReductionStrategy::Reduction), + std::runtime_error); +} diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp new file mode 100644 index 0000000000..399f3f11e8 --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_types.hpp @@ -0,0 +1,25 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" + +using F16 = ck_tile::half_t; +using F32 = float; +using BF16 = ck_tile::bf16_t; + +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// clang-format off +using KernelTypesStreamK = ::testing::Types< +// ALayout BLayout CLayout ADataType BDataType AccDataType CDataType + std::tuple< Row, Col, Row, F16, F16, F32, F16>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16> +>; + +// clang-format on diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp new file mode 100644 index 0000000000..b8a55b024d --- /dev/null +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -0,0 +1,282 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + + // The logic below may need to become more advanced once bugs in Stream-K Tile Partitioner are + // resolved. Because the number of WGs contributing to a macro tile in C may not be the same for + // all macro tiles in C. 
+ + // Calculate error due to more than 1 WG contributing to the same macro tile in C + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +class TestCkTileStreamK : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using AccDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + using DsLayout = ck_tile::tuple<>; + using DsDataType = ck_tile::tuple<>; + + template + void invoke_streamk(const ck_tile::StreamKHostArgs& args, + const ck_tile::stream_config& s, + int num_cu, + int occupancy) + { + + constexpr ck_tile::index_t M_Tile = 128; + constexpr ck_tile::index_t N_Tile = 128; + constexpr ck_tile::index_t K_Tile = 32; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + constexpr bool kPadM = PadM; + constexpr bool kPadN = PadN; + constexpr bool kPadK = PadK; + constexpr bool preshuffle = Preshuffle; + + constexpr bool DoubleSmemBuffer = false; + constexpr int kBlockPerCu = 1; + constexpr bool StructuredSparsity = false; + constexpr bool NumWaveGroup = 1; + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::StreamKTilePartitioner; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + const auto Run = [&](const auto memory_operation_) { + 
constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // For initial testing, we will just test with one pipeline. + // More extensive testing is coming later and will test other pipelines. + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + UniversalGemmProblem::TransposeC, + memory_operation>>; + + using Kernel = ck_tile::StreamKKernel; + + auto kargs = Kernel::MakeKernelArgs(args, num_cu, occupancy); + + if(!Kernel::IsSupportedArgument(kargs)) + { + EXPECT_TRUE(false); + } + + dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); + dim3 block_dims = Kernel::BlockSize(); + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); + }; + + Run(ck_tile::integral_constant{}); + } + + public: + // Since Stream-K is built on gfx9, the lower bound for CUs is 104. Thus, we default num_cu to + // 104 and occupancy to 1 to ensure tests are reproducible on different architectures. 
+ void Run(ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + uint32_t num_sk_blocks = 0xffffffff, + ck_tile::StreamKReductionStrategy reduction_strategy = + ck_tile::StreamKReductionStrategy::Atomic, + int occupancy = 1, + int num_cu = 104, + ck_tile::index_t stride_A = 0, + ck_tile::index_t stride_B = 0, + ck_tile::index_t stride_C = 0) + { + + using namespace ck_tile::literals; + + if(reduction_strategy == ck_tile::StreamKReductionStrategy::Reduction) + { + throw std::runtime_error("Reduction Strategy is current unsupported!\n"); + } + + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + stride_A = f_get_default_stride(M, K, stride_A, ALayout{}); + stride_B = f_get_default_stride(K, N, stride_B, BLayout{}); + stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); + + ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, ALayout{})); + ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, BLayout{})); + ck_tile::HostTensor c_m_n_dev_result( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11939}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-5, 5, /*seed*/ 11940}(b_k_n); + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + 
a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + ck_tile::StreamKHostArgs args{a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + c_m_n_dev_buf.GetDeviceBuffer(), + M, + N, + K, + stride_A, + stride_B, + stride_C, + reduction_strategy, + num_sk_blocks}; + + invoke_streamk( + args, ck_tile::stream_config{nullptr, false, 0, 0, 1}, num_cu, occupancy); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, /*kbatch*/ 1, max_accumulated_value); + + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass); + }; +}; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp index b1521fc35a..ed1b1e32ab 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_kernel_types.hpp @@ -13,6 +13,7 @@ using F16 = ck_tile::half_t; using F32 = float; using F8 = ck_tile::fp8_t; using BF16 = ck_tile::bf16_t; +using I4 = ck_tile::pk_int4_t; using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; @@ -20,20 +21,24 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Default = ck_tile::integral_constant; -using WeightPreshuffle = - ck_tile::integral_constant; - -// Adding alias for the F8 parameters to facilitate skipping tests. 
-// This alias can be removed once test failures are fixed. -using F8Types = std::tuple; +using WeightPreshuffleV1 = + ck_tile::integral_constant; +using WeightPreshuffleV2 = + ck_tile::integral_constant; // clang-format off using KernelTypesWeightPreshuffle = ::testing::Types< - std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffle>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle> -#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 - , F8Types + std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV1>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV1> +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + , + std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV1>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV1> #endif >; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc index 389e0d53ea..bb56c63413 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc @@ -20,7 +20,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } @@ -48,7 +48,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping 
this test due to failures with F8"; } @@ -77,7 +77,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } @@ -106,7 +106,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x2048x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 42d0149498..62f819ac1e 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" @@ -34,20 +35,31 @@ auto calculate_rtol_atol(const ck_tile::index_t K, enum struct GemmPipelineType { - WeightPreshuffle + WeightPreshuffleV1, + WeightPreshuffleV2 }; template struct GemmPipelineTypeSelector; template -struct GemmPipelineTypeSelector +struct GemmPipelineTypeSelector { using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; - static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; } + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV1"; } }; + +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + + static constexpr auto GetName() { return 
"GemmPipelineAgBgCrWeightPreshuffleV2"; } +}; + template struct config { @@ -122,7 +134,8 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = false; + constexpr bool DoubleSmemBuffer = + (PipelineType == GemmPipelineType::WeightPreshuffleV2) ? true : false; // TODO: For now - but this should also be a test parameter constexpr bool TransposeC = false; @@ -391,10 +404,19 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); - a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_shuffle_host_dev = b_shuffle_host; + ck_tile::permute_vectors_i4x4_b(b_shuffle_host_dev); + b_k_n_dev_buf.ToDevice(b_shuffle_host_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + } c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); diff --git a/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..68120efc7e --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/CMakeLists.txt @@ -0,0 +1,9 @@ +set(EXAMPLE_GEMM_COMPILE_OPTIONS) +if(CK_USE_OCP_FP8) + list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) +endif() + +if(GPU_TARGETS MATCHES "gfx94|gfx95") + add_gtest_executable(test_ck_tile_grouped_gemm_preshuffle test_grouped_gemm_preshuffle.cpp) + target_compile_options(test_ck_tile_grouped_gemm_preshuffle PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +endif() diff --git 
a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp new file mode 100644 index 0000000000..cf10853b3f --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle.cpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" + +#include "ck_tile/host.hpp" +#include "test_grouped_gemm_preshuffle_util.hpp" + +using F16 = ck_tile::half_t; +using F8 = ck_tile::fp8_t; +using F32 = float; +using Row = ck_tile::tensor_layout::gemm::RowMajor; +using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + +// Custom tuple-like structure for kernel configuration +template +struct KernelConfig +{ + using ALayoutType = ALayout_; + using BLayoutType = BLayout_; + using CLayoutType = CLayout_; + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + static constexpr int M_Tile_ = M_Tile_val_; + static constexpr int N_Tile_ = N_Tile_val_; + static constexpr int K_Tile_ = K_Tile_val_; + static constexpr int BlockPerCu_ = BlockPerCu_val_; +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>, + KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>, + KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2> + >; +// clang-format on + +TYPED_TEST_SUITE(TestCkTileGroupedGemmPreshuffle, KernelTypes); + +#include "test_grouped_gemm_preshuffle_ut_cases.inc" +#include "test_grouped_gemm_preshuffle_prefill_cases.inc" diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc 
b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc new file mode 100644 index 0000000000..340d807ba2 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_prefill_cases.inc @@ -0,0 +1,61 @@ +#pragma once + +// Test with prefill config struct +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, PrefillVariant) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + + Ms.push_back(256 + 128 * i); + Ns.push_back(256 + 128 * i); + Ks.push_back(128 * (i + 1)); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} + +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, VariedDimensions) +{ + const int group_count = 6; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + std::vector> test_cases = {{64, 128, 256}, + {128, 256, 512}, + {256, 512, 1024}, + {512, 256, 128}, + {128, 128, 128}, + {64, 512, 256}}; + + for(int i = 0; i < group_count; i++) + { + auto [M, N, K] = test_cases[i]; + Ms.push_back(M); + Ns.push_back(N); + Ks.push_back(K); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc new file mode 100644 index 0000000000..beca5e62b5 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_ut_cases.inc @@ -0,0 +1,53 @@ +#pragma once + +// kPadK is not needed for these k values 
+TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKFalse) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 256 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} + +// kPadK is needed to be true for these k values +TYPED_TEST(TestCkTileGroupedGemmPreshuffle, kPadKTrue) +{ + const int group_count = 4; + const int kbatch = 1; + std::vector Ms; + std::vector Ns; + std::vector Ks; + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + for(int i = 0; i < group_count; i++) + { + Ms.push_back(256 + 256 * i); + Ns.push_back(256 + 512 * i); + Ks.push_back(512 + 128 * i); + + stride_As.push_back(Ks[i]); + stride_Bs.push_back(Ks[i]); + stride_Cs.push_back(Ns[i]); + } + + this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, kbatch, group_count); +} diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp new file mode 100644 index 0000000000..799a5f2907 --- /dev/null +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -0,0 +1,374 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +template +constexpr ck_tile::index_t get_k_warp_tile_flatmm() +{ +#if defined(CK_GFX950_SUPPORT) + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 64; + else + return sizeof(PrecType) == 2 ? 32 : 128; +#else + if constexpr(M_Warp_Tile == 32) + return sizeof(PrecType) == 2 ? 16 : 32; + else + return sizeof(PrecType) == 2 ? 32 : 64; +#endif +} + +template +class TestCkTileGroupedGemmPreshuffle : public ::testing::Test +{ + protected: + using ALayout = typename Tuple::ALayoutType; + using BLayout = typename Tuple::BLayoutType; + using CLayout = typename Tuple::CLayoutType; + using ADataType = typename Tuple::ADataType; + using BDataType = typename Tuple::BDataType; + using AccDataType = typename Tuple::AccDataType; + using CDataType = typename Tuple::CDataType; + using PrecType = BDataType; + using DsLayout = ck_tile::tuple<>; // not used + using DsDataType = ck_tile::tuple<>; // not used + + static const bool kPadM = false; + static const bool kPadN = false; + static const bool kPadK = true; // preshuffle pipeline requires k padding + + static const int kBlockPerCu = Tuple::BlockPerCu_; + + // Tile dimensions from tuple + static const ck_tile::index_t M_Tile = Tuple::M_Tile_; + static const ck_tile::index_t N_Tile = Tuple::N_Tile_; + static const ck_tile::index_t K_Tile = Tuple::K_Tile_; + + static const ck_tile::index_t M_Warp = 1; + static const ck_tile::index_t N_Warp = 4; + static const ck_tile::index_t K_Warp = 1; + + static const ck_tile::index_t M_Warp_Tile = 16; + static const ck_tile::index_t N_Warp_Tile = 16; + static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + + static constexpr bool 
DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem + static constexpr bool TransposeC = false; // transpose c is not supported + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + + template + auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) + { + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); + } + + using grouped_gemm_kargs = ck_tile::GroupedGemmHostArgs; + inline std::size_t get_workspace_size(const std::vector& gemm_descs) + { + return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg); + } + + template + auto shuffle_b(const ck_tile::HostTensor& t) + { + assert(t.get_lengths().size() == 2); + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + constexpr int divisor = N_Warp_Tile == 32 ? 
2 : 4; + ck_tile::HostTensor t_view( + {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); + } + + template + void invoke_grouped_gemm(const std::vector& gemm_descs, + const ck_tile::stream_config& s, + void* kargs_ptr) + { + + using GemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + using TilePartitioner = ck_tile:: + GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + // for testing purposes, we can hardcode the values here as we know what is compatible with + // the pipeline + using GemmUniversalTraits = + ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = + ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + + const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile; + const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile; + const ck_tile::index_t num_loop = + ck_tile::GemmSpatiallyLocalTilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = [&](const auto has_hot_loop_, + const auto tail_number_, + const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = + ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + 
const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); + + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(gemm_descs[0].k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + // EXPECT TO FAIL because splitk is not supported + EXPECT_FALSE(true); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + } + + public: + void Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + std::vector& stride_As, + std::vector& stride_Bs, + std::vector& stride_Cs, + const int kbatch = 1, + const int group_count = 16) + { + + using namespace ck_tile::literals; + auto f_host_tensor_descriptor = [](std::size_t row, + std::size_t col, + std::size_t stride, + auto layout) { + if constexpr(std::is_same_v) + { + return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + auto f_get_default_stride = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + if(stride == 0) + { + if constexpr(std::is_same_v) + { + return col; + } + else + { + return row; + } + } + else + return stride; + }; + + std::vector> a_m_k_tensors; + std::vector> b_k_n_tensors; + std::vector> c_m_n_tensors; + + a_m_k_tensors.reserve(group_count); + b_k_n_tensors.reserve(group_count); + c_m_n_tensors.reserve(group_count); + + std::vector> a_m_k_dev_buf; + std::vector> b_k_n_dev_buf; + std::vector> c_m_n_dev_buf; + + a_m_k_dev_buf.reserve(group_count); + 
b_k_n_dev_buf.reserve(group_count); + c_m_n_dev_buf.reserve(group_count); + + std::vector gemm_descs; + gemm_descs.reserve(group_count); + + for(int i = 0; i < group_count; ++i) + { + const ck_tile::index_t M = Ms[i]; + const ck_tile::index_t N = Ns[i]; + const ck_tile::index_t K = Ks[i]; + + stride_As[i] = f_get_default_stride(M, K, stride_As[i], ALayout{}); + stride_Bs[i] = f_get_default_stride(K, N, stride_Bs[i], BLayout{}); + stride_Cs[i] = f_get_default_stride(M, N, stride_Cs[i], CLayout{}); + + a_m_k_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, K, stride_As[i], ALayout{}))); + b_k_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(K, N, stride_Bs[i], BLayout{}))); + c_m_n_tensors.push_back(ck_tile::HostTensor( + f_host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); + + ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensors[i]); + ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); + + // Host-side preshuffle of B + auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]); + + a_m_k_dev_buf.push_back(std::make_unique( + a_m_k_tensors[i].get_element_space_size_in_bytes())); + b_k_n_dev_buf.push_back(std::make_unique( + b_shuffle_host.get_element_space_size_in_bytes())); + c_m_n_dev_buf.push_back(std::make_unique( + c_m_n_tensors[i].get_element_space_size_in_bytes())); + + a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data()); + b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data()); + c_m_n_dev_buf[i]->SetZero(); + c_m_n_tensors[i].SetZero(); + + const void* p_a = a_m_k_dev_buf[i]->GetDeviceBuffer(); + const void* p_b = b_k_n_dev_buf[i]->GetDeviceBuffer(); + void* p_c = c_m_n_dev_buf[i]->GetDeviceBuffer(); + + gemm_descs.push_back( + {p_a, p_b, p_c, kbatch, M, N, K, stride_As[i], stride_Bs[i], stride_Cs[i]}); + } + + ck_tile::DeviceMem gemm_workspace; + gemm_workspace.Realloc(get_workspace_size(gemm_descs)); + + invoke_grouped_gemm(gemm_descs, + ck_tile::stream_config{nullptr, false, 1}, + 
gemm_workspace.GetDeviceBuffer()); + + // Copy results back to host for validation + for(int i = 0; i < group_count; i++) + { + c_m_n_dev_buf[i]->FromDevice(c_m_n_tensors[i].data()); + } + + bool pass{true}; + for(int i = 0; i < group_count; ++i) + { + ck_tile::HostTensor c_m_n_host_ref( + f_host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); + c_m_n_host_ref.SetZero(); + ck_tile::reference_gemm( + a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + calculate_rtol_atol( + Ks[i], kbatch, max_accumulated_value); + pass &= ck_tile::check_err(c_m_n_tensors[i], + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + } + EXPECT_TRUE(pass); + } +}; diff --git a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp index f2875c72c8..c6ef822f64 100644 --- a/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp +++ b/test/ck_tile/moe_smoothquant/instances/moe_smoothquant_instance_common.hpp @@ -53,7 +53,7 @@ float moe_smoothquant_(const S& s, A a) using Kernel = ck_tile::MoeSmoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/rmsnorm2d/generate.py b/test/ck_tile/rmsnorm2d/generate.py index 5eded8b310..3bcc427e83 100644 --- a/test/ck_tile/rmsnorm2d/generate.py +++ b/test/ck_tile/rmsnorm2d/generate.py @@ -201,7 +201,7 @@ float rmsnorm2d_fwd_(const S& s, A a) using Kernel = ck_tile::Rmsnorm2dFwd; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr 
ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp index 8929289cdb..138afcffaf 100644 --- a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp +++ b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp @@ -49,7 +49,7 @@ float smoothquant_(const S& s, A a) using Kernel = ck_tile::Smoothquant; const dim3 grids = Kernel::GridSize(a); - constexpr dim3 blocks = Kernel::BlockSize(); + const dim3 blocks = Kernel::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = 1; auto kargs = Kernel::MakeKargs(a); diff --git a/test/data_type/test_bf8_fnuz.cpp b/test/data_type/test_bf8_fnuz.cpp index 4ff796a614..f028c0da73 100644 --- a/test/data_type/test_bf8_fnuz.cpp +++ b/test/data_type/test_bf8_fnuz.cpp @@ -43,9 +43,8 @@ TEST(BF8FNUZ, ConvertFP32Nearest) type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); // convert inf float to bf8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity())); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); @@ -80,9 +79,8 @@ TEST(BF8FNUZ, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float to bf8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity())); // positive norm float value to bf8 and back, check if holds float pos_float = 0.0000762939f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); @@ 
-118,9 +116,8 @@ TEST(BF8FNUZ, ConvertFP16Nearest) type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); @@ -155,9 +152,8 @@ TEST(BF8FNUZ, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to bf8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to bf8 and back, check if holds half_t pos_half = half_t{0.0000762939}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); diff --git a/test/data_type/test_fp8_fnuz.cpp b/test/data_type/test_fp8_fnuz.cpp index c2ec6dad94..0cf775f947 100644 --- a/test/data_type/test_fp8_fnuz.cpp +++ b/test/data_type/test_fp8_fnuz.cpp @@ -48,9 +48,8 @@ TEST(FP8FNUZ, ConvertFP32Nearest) type_convert(f8_convert_rne(std::numeric_limits::max())), abs_tol); // convert inf float to f8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(std::numeric_limits::infinity())); // positive norm float value to fp8 and back, check if holds float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_rne(pos_float)), abs_tol); @@ -85,9 +84,8 @@ TEST(FP8FNUZ, ConvertFP32Stochastic) type_convert(f8_convert_sr(std::numeric_limits::max())), abs_tol); // convert inf float 
to f8_fnuz_t and check if it is qNan - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(std::numeric_limits::infinity()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(std::numeric_limits::infinity())); // positive norm float value to fp8 and back, check if holds float pos_float = 0.017578125f; ASSERT_NEAR(pos_float, type_convert(f8_convert_sr(pos_float)), abs_tol); @@ -122,9 +120,8 @@ TEST(FP8FNUZ, ConvertFP16Nearest) type_convert(f8_convert_rne(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_rne(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_rne(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to fp8 and back, check if holds half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_rne(pos_half)), abs_tol); @@ -159,9 +156,8 @@ TEST(FP8FNUZ, ConvertFP16Stochastic) type_convert(f8_convert_sr(ck::NumericLimits::Max())), abs_tol); // convert QuietNaN fp16 to f8_fnuz_t and check if it is QuietNaN - ASSERT_NEAR(ck::NumericLimits::QuietNaN(), - f8_convert_sr(ck::NumericLimits::QuietNaN()), - abs_tol); + ASSERT_EQ(ck::NumericLimits::QuietNaN(), + f8_convert_sr(ck::NumericLimits::QuietNaN())); // positive norm fp16 value to fp8 and back, check if holds half_t pos_half = half_t{0.017578125}; ASSERT_NEAR(pos_half, type_convert(f8_convert_sr(pos_half)), abs_tol); diff --git a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp index 2a9421fcd1..354d1fc23b 100644 --- a/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp +++ b/test/grouped_convnd_bwd_weight/test_grouped_convnd_bwd_weight_interface_xdl.cpp @@ -52,7 +52,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test // clang-format on 
ck::utils::conv::ConvParam conv_param; - std::vector split_ks{-1, 2}; + ck::index_t split_k_ = 2; template bool Run() @@ -96,30 +96,24 @@ class TestGroupedConvndBwdWeight : public ::testing::Test auto conv = GroupedConvBwdWeightDeviceInstance{}; - bool is_supported = true; - - for(const auto split_k : split_ks) - { - auto argument = conv.MakeArgument(nullptr, - nullptr, - nullptr, - input_lengths, - input_strides, - filter_lengths, - weights_strides, - output_lengths, - output_strides, - conv_filter_strides, - conv_filter_dilations, - input_left_pads, - input_right_pads, - PassThrough{}, - PassThrough{}, - PassThrough{}, - split_k); - is_supported &= conv.IsSupportedArgument(argument); - } - return is_supported; + auto argument = conv.MakeArgument(nullptr, + nullptr, + nullptr, + input_lengths, + input_strides, + filter_lengths, + weights_strides, + output_lengths, + output_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + PassThrough{}, + PassThrough{}, + PassThrough{}, + split_k_); + return conv.IsSupportedArgument(argument); } }; @@ -183,3 +177,12 @@ TYPED_TEST(TestGroupedConvndBwdWeightDefault, VectorLoadCheck) is_supported = this->template Run<2>(); EXPECT_FALSE(is_supported); } + +TYPED_TEST(TestGroupedConvndBwdWeightDefault, SingleStageAutoDeduce) +{ + // Supported version but with auto deduce and single stage + this->conv_param = {2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}}; + this->split_k_ = -1; + bool is_supported = this->template Run<2>(); + EXPECT_FALSE(is_supported); +} diff --git a/test/quantization/CMakeLists.txt b/test/quantization/CMakeLists.txt new file mode 100644 index 0000000000..89a99f5e5d --- /dev/null +++ b/test/quantization/CMakeLists.txt @@ -0,0 +1,2 @@ +add_custom_target(test_quantization) +add_subdirectory(gemm) diff --git a/test/quantization/gemm/CMakeLists.txt b/test/quantization/gemm/CMakeLists.txt new file mode 100644 index 0000000000..630e6e09c9 --- /dev/null 
+++ b/test/quantization/gemm/CMakeLists.txt @@ -0,0 +1,9 @@ +add_custom_target(test_gemm_quantization_targets) + +add_gtest_executable(test_gemm_quantization test_gemm_quantization.cpp) +if(result EQUAL 0) + target_link_libraries(test_gemm_quantization PRIVATE utility device_quantization_instance) + add_dependencies(test_gemm_quantization_targets test_gemm_quantization) +endif() + +add_dependencies(test_quantization test_gemm_quantization_targets) diff --git a/test/quantization/gemm/test_gemm_quantization.cpp b/test/quantization/gemm/test_gemm_quantization.cpp new file mode 100644 index 0000000000..9981ae8a41 --- /dev/null +++ b/test/quantization/gemm/test_gemm_quantization.cpp @@ -0,0 +1,40 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_quantization_impl.hpp" +#include "test_gemm_quantization_util.hpp" + +using I8 = int8_t; +using I32 = int32_t; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +class TestGemmQuantization : public ck::test::TestGemmQuantizationCommon +{ + protected: + using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float); + + ProfileCall GetImpl() override + { + return &ck::profiler::profile_gemm_quantization_impl< + typename ck::test::TestGemmQuantizationCommon::ADataType, + typename ck::test::TestGemmQuantizationCommon::BDataType, + typename ck::test::TestGemmQuantizationCommon::AccDataType, + typename ck::test::TestGemmQuantizationCommon::EDataType, + typename ck::test::TestGemmQuantizationCommon::ALayout, + typename ck::test::TestGemmQuantizationCommon::BLayout, + typename ck::test::TestGemmQuantizationCommon::ELayout>; + } +}; + +using KernelTypes = ::testing::Types, + std::tuple, + std::tuple, + std::tuple>; + +TYPED_TEST_SUITE(TestGemmQuantization, KernelTypes); + +#include 
"test_gemm_quantization_ut_cases.inc" diff --git a/test/quantization/gemm/test_gemm_quantization_ut_cases.inc b/test/quantization/gemm/test_gemm_quantization_ut_cases.inc new file mode 100644 index 0000000000..83a13e4a85 --- /dev/null +++ b/test/quantization/gemm/test_gemm_quantization_ut_cases.inc @@ -0,0 +1,41 @@ +#pragma once + +TYPED_TEST(TestGemmQuantization, SmallM) +{ + std::vector Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 320; + + for(int M : Ms) + this->Run({{M, N, K}}); +} + +TYPED_TEST(TestGemmQuantization, MidLargeM) +{ + std::vector Ms{127, 255, 312, 799, 1573}; + constexpr int N = 1024; + constexpr int K = 320; + + for(int M : Ms) + this->Run({{M, N, K}}); +} + +TYPED_TEST(TestGemmQuantization, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + for(int M : Ms) + this->Run({{M, N, K}}); +} + +TYPED_TEST(TestGemmQuantization, Regular) +{ + constexpr int M = 512; + constexpr int N = 512; + std::vector Ks{512}; + + for(int K : Ks) + this->Run({{M, N, K}}); +} diff --git a/test/quantization/gemm/test_gemm_quantization_util.hpp b/test/quantization/gemm/test_gemm_quantization_util.hpp new file mode 100644 index 0000000000..e1ca0de2db --- /dev/null +++ b/test/quantization/gemm/test_gemm_quantization_util.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using I8 = int8_t; +using I32 = int32_t; + +namespace ck { +namespace test { + +using TestMatrixSizes = std::vector>; + +static const TestMatrixSizes DefaultTestMatrixSizes = { + {16, 32, 64}, {512, 2048, 4096}, {2048, 1024, 16}}; + +template +class TestGemmQuantizationCommon : public ::testing::Test +{ + protected: + using ADataType = std::tuple_element_t<0, Tuple>; + using BDataType = std::tuple_element_t<1, Tuple>; + using AccDataType = std::tuple_element_t<2, Tuple>; + using EDataType = std::tuple_element_t<3, Tuple>; + using ALayout = std::tuple_element_t<4, Tuple>; + using BLayout = std::tuple_element_t<5, Tuple>; + using ELayout = std::tuple_element_t<6, Tuple>; + + using ProfileCall = bool (*const)(int, int, bool, bool, int, int, int, int, int, int, float); + + virtual ProfileCall GetImpl() = 0; + + void Run(const TestMatrixSizes& lengths = DefaultTestMatrixSizes) + { + bool all_success = true; + + for(auto length : lengths) + { + int M = length[0]; + int N = length[1]; + int K = length[2]; + int StrideA = ck::is_same_v ? K : M; + int StrideB = ck::is_same_v ? N : K; + int StrideE = ck::is_same_v ? 
N : M; + float requant_scale = 0.03f; + + all_success = + all_success & + GetImpl()(1, 1, false, true, M, N, K, StrideA, StrideB, StrideE, requant_scale); + } + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/tile_engine/ops/gemm/CMakeLists.txt b/tile_engine/ops/gemm/CMakeLists.txt index d52351af2d..77165ae0fa 100644 --- a/tile_engine/ops/gemm/CMakeLists.txt +++ b/tile_engine/ops/gemm/CMakeLists.txt @@ -13,38 +13,38 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ message(WARNING "Skipping individual GEMM target ${datatype}_${layout}_${trait}_${tile_config}: No supported GPU targets") return() endif() - + # Parse tile configuration: format is tile_mxtile_nxtile_k_warp_mxwarp_nxwarp_k_warp_tile_mxwarp_tile_nxwarp_tile_k # First split by underscore to get three groups string(REPLACE "_" ";" config_groups ${tile_config}) list(GET config_groups 0 tile_dims) # e.g., 256x256x32 list(GET config_groups 1 warp_dims) # e.g., 4x1x1 list(GET config_groups 2 warp_tile_dims) # e.g., 16x16x16 - + # Parse tile dimensions string(REPLACE "x" ";" tile_parts ${tile_dims}) list(GET tile_parts 0 tile_m) list(GET tile_parts 1 tile_n) list(GET tile_parts 2 tile_k) - + # Parse warp dimensions string(REPLACE "x" ";" warp_parts ${warp_dims}) list(GET warp_parts 0 warp_m) list(GET warp_parts 1 warp_n) list(GET warp_parts 2 warp_k) - + # Parse warp tile dimensions string(REPLACE "x" ";" warp_tile_parts ${warp_tile_dims}) list(GET warp_tile_parts 0 warp_tile_m) list(GET warp_tile_parts 1 warp_tile_n) list(GET warp_tile_parts 2 warp_tile_k) - + set(target_name "benchmark_gemm_${datatype}_${layout}_${trait}_${tile_config}") set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - + # Generate the single instance header for this kernel set(instance_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") - + # Add custom command to generate the header file at build time 
add_custom_command( OUTPUT ${instance_header} @@ -60,27 +60,27 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ DEPENDS ${GEMM_SOURCE_DIR}/gemm_instance_builder.py ${config_json} COMMENT "Generating ${instance_header}" ) - + # Create the executable - add_executable(${target_name} + add_executable(${target_name} ${GEMM_SOURCE_DIR}/benchmark_gemm_single.cpp ${instance_header} ) - + # Set GPU architectures set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_GPU_TARGETS_INDIVIDUAL}) - + # Set compile definitions target_compile_definitions(${target_name} PRIVATE GEMM_SINGLE_INSTANCE_HPP="${instance_header}" ) - + # Include directories target_include_directories(${target_name} PRIVATE ${GEMM_SOURCE_DIR} ${working_path} ) - + # Compile options target_compile_options(${target_name} PRIVATE -Wno-undefined-func-template @@ -88,19 +88,19 @@ function(create_individual_gemm_target datatype layout trait tile_config config_ --offload-compress -include ${instance_header} ) - + # Add to collection targets add_dependencies(benchmark_gemm_all ${target_name}) add_dependencies(benchmark_gemm_${datatype} ${target_name}) add_dependencies(benchmark_gemm_${layout} ${target_name}) add_dependencies(benchmark_gemm_${datatype}_${layout} ${target_name}) - + # Add to trait-specific targets string(REPLACE "_" ";" trait_parts ${trait}) list(GET trait_parts 0 pipeline) list(GET trait_parts 1 epilogue) list(GET trait_parts 2 scheduler) - + add_dependencies(benchmark_gemm_${pipeline} ${target_name}) add_dependencies(benchmark_gemm_${epilogue} ${target_name}) add_dependencies(benchmark_gemm_${scheduler} ${target_name}) @@ -109,13 +109,13 @@ endfunction() # Function to build individual GEMM targets function(build_individual_gemm_targets datatype layout) set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}") - + # Choose config file # Priority order: # 1. Environment variable GEMM_CONFIG_FILE - # 2. CMake variable GEMM_CONFIG_FILE + # 2. 
CMake variable GEMM_CONFIG_FILE # 3. Default based on layout - + # Check environment variable first if(DEFINED ENV{GEMM_CONFIG_FILE} AND NOT "$ENV{GEMM_CONFIG_FILE}" STREQUAL "") set(config_filename "$ENV{GEMM_CONFIG_FILE}") @@ -130,12 +130,12 @@ function(build_individual_gemm_targets datatype layout) set(json_blob "${CMAKE_CURRENT_LIST_DIR}/configs/default_config.json") message(STATUS " Using default config for layout ${layout}") endif() - + # Check if config file exists if(NOT EXISTS ${json_blob}) message(FATAL_ERROR "Config file not found: ${json_blob}") endif() - + # Determine number of workers for parallel generation if(DEFINED ENV{CMAKE_BUILD_PARALLEL_LEVEL}) set(num_workers $ENV{CMAKE_BUILD_PARALLEL_LEVEL}) @@ -147,17 +147,24 @@ function(build_individual_gemm_targets datatype layout) set(num_workers 8) endif() endif() - + # Generate individual kernel files using parallel version message(STATUS "Generating individual kernels for ${datatype} ${layout} using ${num_workers} workers...") message(STATUS " Working path: ${working_path}") message(STATUS " Config file: ${json_blob}") message(STATUS " Python executable: ${Python3_EXECUTABLE}") message(STATUS " Script path: ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py") - + # Create working directory first file(MAKE_DIRECTORY ${working_path}) - + + message(STATUS "COMMAND: ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_LIST_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --list_kernels") + # First, just list the kernels (fast operation) message(STATUS " Listing kernel configurations...") execute_process( @@ -172,11 +179,11 @@ function(build_individual_gemm_targets datatype layout) OUTPUT_VARIABLE list_output ERROR_VARIABLE list_error ) - + if(NOT ret EQUAL 0) message(FATAL_ERROR "Failed to list kernels for ${datatype} ${layout}: ${list_error}") endif() - + # Read kernel count if(EXISTS ${working_path}/gemm_kernel_count.txt) 
file(READ ${working_path}/gemm_kernel_count.txt kernel_count) @@ -185,7 +192,7 @@ function(build_individual_gemm_targets datatype layout) else() message(FATAL_ERROR "Kernel count file not found") endif() - + # Read kernel list and create targets if(EXISTS ${working_path}/gemm_kernel_list.txt) file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) @@ -195,7 +202,7 @@ function(build_individual_gemm_targets datatype layout) list(GET parts 0 kernel_name) list(GET parts 1 tile_config) list(GET parts 2 trait_combo) - + # Create individual target create_individual_gemm_target("${datatype}" "${layout}" "${trait_combo}" "${tile_config}" "${json_blob}") endforeach() @@ -210,9 +217,9 @@ message(STATUS "GEMM_DATATYPE: ${GEMM_DATATYPE}") message(STATUS "GEMM_LAYOUT: ${GEMM_LAYOUT}") message(STATUS "SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") -# Filter GPU targets to only gfx90a, gfx942, and gfx950 +# Filter GPU targets to only gfx90a, gfx942, gfx950, gfx1201 set(GEMM_GPU_TARGETS_INDIVIDUAL "") -set(DESIRED_TARGETS "gfx90a;gfx942;gfx950") +set(DESIRED_TARGETS "gfx90a;gfx942;gfx950;gfx1201") foreach(target IN LISTS SUPPORTED_GPU_TARGETS) if(target IN_LIST DESIRED_TARGETS) @@ -223,13 +230,13 @@ endforeach() # Skip build if no matching targets found if(NOT GEMM_GPU_TARGETS_INDIVIDUAL) - message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") + message(WARNING "Skipping Tile Engine GEMM build: No supported GPU targets (gfx90a, gfx942, gfx950, gfx1201) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}") else() message(STATUS "Building individual GEMM targets for GPU targets: ${GEMM_GPU_TARGETS_INDIVIDUAL}") # Enable parallel compilation optimizations # Set up job pools for better parallel compilation control - set_property(GLOBAL PROPERTY JOB_POOLS + set_property(GLOBAL PROPERTY JOB_POOLS compile_heavy=4 # Limit heavy compilations to prevent OOM 
compile_normal=16 # Allow more parallel normal compilations ) diff --git a/tile_engine/ops/gemm/codegen_utils.py b/tile_engine/ops/gemm/codegen_utils.py index 6a87193043..98595933b8 100644 --- a/tile_engine/ops/gemm/codegen_utils.py +++ b/tile_engine/ops/gemm/codegen_utils.py @@ -179,6 +179,11 @@ warp_tile_supported_combinations = { [32, 32, 64], ], }, + "gfx1201": { + "fp16_fp16_fp16": [ + [16, 16, 16], + ], + }, } # To Do: remove some unsupported combinations diff --git a/tile_engine/ops/gemm/configs/gfx120x_config.json b/tile_engine/ops/gemm/configs/gfx120x_config.json new file mode 100644 index 0000000000..6c4a5d0ec0 --- /dev/null +++ b/tile_engine/ops/gemm/configs/gfx120x_config.json @@ -0,0 +1,102 @@ +{ + "problem": { + }, + "tile_config": { + "tile_m": { + "values": [ + 256, + 128, + 64 + ] + }, + "tile_n": { + "values": [ + 256, + 128, + 64 + ] + }, + "tile_k": { + "values": [ + 256, + 128, + 64 + ] + }, + "warp_m": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_n": { + "values": [ + 4, + 2, + 1 + ] + }, + "warp_k": { + "values": [ + 1 + ] + }, + "warp_tile_m": { + "values": [ + 16 + ] + }, + "warp_tile_n": { + "values": [ + 16 + ] + }, + "warp_tile_k": { + "values": [ + 16 + ] + } + }, + "trait_config": { + "pipeline": { + "values": [ + "compv3", + "mem" + ] + }, + "scheduler": { + "values": [ + "intrawave", + "interwave" + ] + }, + "epilogue": { + "values": [ + "cshuffle", + "default" + ] + }, + "pad_m": { + "values": [ + false + ] + }, + "pad_n": { + "values": [ + false + ] + }, + "pad_k": { + "values": [ + false + ] + }, + "persistent": { + "values": [ + false, + true + ] + } + } +} diff --git a/tile_engine/ops/gemm/validation_utils.py b/tile_engine/ops/gemm/validation_utils.py index 7367f2446d..c0e109bf11 100644 --- a/tile_engine/ops/gemm/validation_utils.py +++ b/tile_engine/ops/gemm/validation_utils.py @@ -103,6 +103,36 @@ WARP_TILE_SUPPORTED_COMBINATIONS = { [32, 32, 64], ], }, + "gfx1201": { + "fp16_fp16_fp16": [ + [16, 16, 16], + ], + }, +} + +# 
Supported warp tile combinations for different GPU architectures and data types +WARP_SUPPORTED_COMBINATIONS = { + "gfx90a": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx942": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx950": [ + [1, 4, 1], + [2, 2, 1], + [4, 1, 1], + ], + "gfx1201": [ + [2, 4, 1], + [1, 8, 1], + [8, 1, 1], + [4, 2, 1], + ], } # Unsupported trait combinations @@ -155,9 +185,32 @@ def is_trait_combination_valid(pipeline: str, epilogue: str, scheduler: str) -> return (pipeline, epilogue, scheduler) not in TRAIT_UNSUPPORTED_COMBINATIONS -def validate_warp_configuration(warp_m: int, warp_n: int, warp_k: int) -> bool: +def validate_warp_configuration( + warp_m: int, + warp_n: int, + warp_k: int, + gpu_name: str = None, +) -> bool: """Validate warp configuration.""" - return (warp_m, warp_n, warp_k) in [(1, 4, 1), (2, 2, 1), (4, 1, 1)] + if gpu_name is None: + gpu_name = get_gpu_name_by_id(0) + + current_combination = [warp_m, warp_n, warp_k] + + allowed_combinations = WARP_SUPPORTED_COMBINATIONS.get(gpu_name, {}) + if not allowed_combinations: + # If GPU not recognized, try to be permissive but log warning + logging.warning(f"No warp_[m/n/k] combinations found for GPU: {gpu_name}") + return True + + # Check if current combination is in the allowed list + if current_combination not in allowed_combinations: + error_msg = ( + f"Invalid warp tile combination: {current_combination} not in allowed list. " + ) + return False + + return True def validate_dimension_alignment(