diff --git a/CMakeLists.txt b/CMakeLists.txt index 86ed3c96ec..5f4f6a52fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -42,6 +42,8 @@ option(ENABLE_CLANG_CPP_CHECKS "Enables clang tidy, cppcheck" ON) option(MIOPEN_REQ_LIBS_ONLY "Build only the MIOpen required libraries" OFF) option(CK_EXPERIMENTAL_BUILDER "Enable experimental builder" OFF) option(BUILD_MHA_LIB "Build the static library for flash attention" OFF) +option(FORCE_DISABLE_XDL "Skip compiling XDL specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) +option(FORCE_DISABLE_WMMA "Skip compiling WMMA specific instances (even if supported GPUs are included in GPU_TARGETS)" OFF) if(CK_EXPERIMENTAL_BUILDER) add_definitions(-DCK_EXPERIMENTAL_BUILDER) @@ -232,12 +234,12 @@ message(STATUS "Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}" # Cache SUPPORTED_GPU_TARGETS for debug set(SUPPORTED_GPU_TARGETS "${SUPPORTED_GPU_TARGETS}" CACHE STRING "List of supported GPU targets") -if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx9|gfx11|gfx12" AND NOT FORCE_DISABLE_XDL) message(STATUS "Enabling XDL instances") add_definitions(-DCK_USE_XDL) set(CK_USE_XDL "ON") endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx94" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND NOT FORCE_DISABLE_XDL) message(STATUS "Enabling XDL FP8 gemms on native architectures") add_definitions(-DCK_USE_GFX94) set(CK_USE_GFX94 "ON") @@ -250,7 +252,7 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx10") add_definitions(-DCK_GFX1030_SUPPORT) endif() -if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if ((SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") AND NOT FORCE_DISABLE_WMMA) message(STATUS "Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) set(CK_USE_WMMA "ON") @@ -260,7 +262,7 @@ endif() # define the macro with the 
current value (0 or 1) add_definitions(-DCK_TILE_USE_WMMA=${CK_TILE_USE_WMMA}) -if (SUPPORTED_GPU_TARGETS MATCHES "gfx12") +if (SUPPORTED_GPU_TARGETS MATCHES "gfx12" AND NOT FORCE_DISABLE_WMMA) message(STATUS "Enabling WMMA FP8 gemms on native architectures") add_definitions(-DCK_USE_WMMA_FP8) set(CK_USE_WMMA_FP8 "ON") diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index b1e3d86971..ce41c3310f 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -37,6 +37,13 @@ if(USE_BITINT_EXTENSION_INT4) add_example_dependencies(example_grouped_gemm_xdl example_grouped_gemm_xdl_int4) endif() +add_custom_target(example_grouped_gemm_wmma) +add_example_executable(example_grouped_gemm_wmma_splitk_fp16 grouped_gemm_wmma_splitk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_fp16) + +add_example_executable(example_grouped_gemm_wmma_splitk_bf16 grouped_gemm_wmma_splitk_bf16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_splitk_bf16) + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp new file mode 100644 index 0000000000..e4da397c23 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_bf16.cpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = BF16; +using BDataType = BF16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| 
ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; + +// clang-format on + +#define EXAMPLE_USE_SPLITK +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp new file mode 100644 index 0000000000..d5b2205892 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_splitk_fp16.cpp @@ -0,0 +1,71 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/utility/ignore.hpp" +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_CShuffleV3 + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| 
BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; + +// clang-format on + +#define EXAMPLE_USE_SPLITK +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index 0e64fbb7c6..764b533455 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -19,6 +19,10 @@ struct ProblemSize final std::vector stride_Cs; ck::index_t group_count; + +#if defined(EXAMPLE_USE_SPLITK) + ck::index_t k_batch; +#endif }; struct ExecutionConfig final @@ -177,6 +181,10 @@ bool run_grouped_gemm(const ProblemSize& problem_size, 
const ExecutionConfig& co auto argument = gemm.MakeArgument( p_a, p_b, p_Ds, p_c, gemm_descs, a_element_op, b_element_op, c_element_op); +#if defined(EXAMPLE_USE_SPLITK) + gemm.SetKBatchSize(&argument, problem_size.k_batch); +#endif + std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument); std::size_t kargs_size = gemm.GetDeviceKernelArgSize(&argument); std::size_t hargs_size = gemm.GetHostKernelArgSize(&argument); @@ -285,12 +293,15 @@ bool run_grouped_gemm_example(int argc, char* argv[]) ExecutionConfig config; problem_size.group_count = 16; +#if defined(EXAMPLE_USE_SPLITK) + problem_size.k_batch = 1; +#endif if(argc == 1) { // use default cases } - else if(argc == 4 || argc == 6) + else if(argc == 4 || argc == 6 || argc == 7) { config.do_verification = std::stoi(argv[1]); config.init_method = std::stoi(argv[2]); @@ -300,6 +311,13 @@ bool run_grouped_gemm_example(int argc, char* argv[]) config.async_hargs = std::stoi(argv[4]); problem_size.group_count = std::stoi(argv[5]); } + +#if defined(EXAMPLE_USE_SPLITK) + if(argc == 7) + { + problem_size.k_batch = std::stoi(argv[6]); + } +#endif } else { @@ -307,7 +325,10 @@ bool run_grouped_gemm_example(int argc, char* argv[]) printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); printf("arg3: time kernel (0=n0, 1=yes)\n"); printf("arg4: async hargs (0=n0, 1=yes)\n"); - printf("arg5: group count (default=16)"); + printf("arg5: group count (default=16)\n"); +#if defined(EXAMPLE_USE_SPLITK) + printf("arg6: k-batch count (default=1)\n"); +#endif exit(1); } diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp index 8d3fd146bc..0134465347 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm.cpp @@ -158,7 +158,7 @@ auto create_args(int argc, char* argv[]) .insert("stride_c", "0", "Tensor C stride") .insert("v", "1", "0. No validation, 1. Validation on CPU, 2. 
Validation on GPU") .insert( - "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp6xfp6, fp8xfp8") + "mx_prec", "fp4xfp4", "data type for activation and weight, support: fp4xfp4, fp8xfp8") .insert("warmup", "50", "number of iterations before benchmark the kernel") .insert("repeat", "100", "number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp index 9c12509d59..f177ef04ca 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -75,7 +75,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, HasHotLoop, TailNum>; - using MXFlatmmPipeline = ck_tile::MXF4FlatmmPipelineAGmemBGmemCRegV1; + using MXFlatmmPipeline = ck_tile::MXFlatmmPipelineAGmemBGmemCRegV1; using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner.cpp| GemmConfigQuantDecode (or) GemmConfigBQuantPrefill | +| For selecting BQuant | bquant | gemm_bquant_quantgrouped_.cpp| GemmConfigQuantDecode (or) GemmConfigQuantPrefill | | For selecting BQuant with Preshuffle quant | bquant | gemm_bquant_quantgrouped_preshufflequant.cpp| GemmConfigPreshuffleQuantDecode (or) GemmConfigPreshuffleBQuantPrefill | | For selecting PreShuffle B with BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb.cpp| GemmConfigPreshuffleB_BQuant_Decode (or) GemmConfigPreshuffleB_BQuant_Prefill | For selecting PreShuffle B with preshuffle BQuant | bquant | gemm_bquant_quantgrouped_preshuffleb_preshufflequant.cpp |GemmConfigPreshuffleB_PreshuffleBQuant_Decode (or) GemmConfigPreshuffleB_PreshuffleBQuant_Prefill diff --git a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp index 0f75976602..ad1a4e0d10 100644 --- 
a/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_aquant_quantgrouped.cpp @@ -6,6 +6,10 @@ template using GemmConfig = GemmConfigQuantDecode; +// GemmConfigQuantPrefill is also supported for aquant grouped quantization +// template +// using GemmConfig = GemmConfigQuantPrefill; + void aquant_quantgrouped_instance_factory( std::unordered_map>& lut) { diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index 2dbae9e42c..61fd65960f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index 40cf88624b..1d471068eb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index 5c21d5aa16..280029033b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using 
GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index 80b9a2765e..a277c864bb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigBQuantPrefill; +using GemmConfig = GemmConfigQuantPrefill; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 81032d6452..116661c157 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -221,7 +221,7 @@ struct GemmConfigPreshuffleB_PreshuffleBQuant_Prefill }; template -struct GemmConfigBQuantPrefill : public GemmConfigBase +struct GemmConfigQuantPrefill : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -237,13 +237,13 @@ struct GemmConfigBQuantPrefill : public GemmConfigBase }; template -struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigBQuantPrefill +struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill { static constexpr bool PreshuffleQuant = true; }; template -struct GemmConfigBQuantPrefill_Wmma : public GemmConfigBQuantPrefill +struct GemmConfigBQuantPrefill_Wmma : public GemmConfigQuantPrefill { static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; diff --git a/experimental/builder/README.md b/experimental/builder/README.md index 141a34b9f9..d3f0ec40f9 100644 --- a/experimental/builder/README.md +++ b/experimental/builder/README.md @@ -35,27 +35,41 @@ cmake .. 
``` -## Building and testing +## Building and Testing -During development, all CK Builder tests can be built with command +The builder test suite is organized into two main categories: + +### Smoke Tests (Fast Unit Tests) +Quick unit tests that verify the builder's internal logic without compiling GPU kernels. These complete in under 1 second total and are suitable for frequent execution during development. ```sh -ninja test_ckb_all +ninja smoke-builder ``` -To execute all tests, run +### Regression Tests (Integration Tests) +Integration tests that compile actual GPU kernels to verify that the builder generates valid, compilable code. These are more expensive than smoke tests (can take minutes to compile) but cover more functionality. ```sh -ls bin/test_ckb_* | xargs -n1 sh -c +ninja regression-builder ``` -Some tests involve building old CK convolution factories, which will take a long time. -Hence, one might want to build only single test targets. For example +### Running All Tests +To build and run the complete test suite: + +```sh +ninja check-builder +``` + +### Building Individual Tests +To build and run a specific test: ```sh ninja test_ckb_conv_builder && bin/test_ckb_conv_builder ``` -When adding new tests, please follow the convention where the CMake build target starts with a prefix `test_ckb`. -This allows us to filter out the CK Builder tests from the set full CK repository tests. -Also, the `test_ckb_all` target that builds all CK Builder tests relies on having the `test_ckb` prefix on the CMake build targets. +### Test Organization +- **Smoke tests**: Fast feedback during active development +- **Regression tests**: Thorough validation before submitting changes +- **Factory tests**: Expensive tests that build all MIOpen kernels (included in regression tests) + +When adding new tests, please follow the convention where the CMake build target starts with a prefix `test_ckb`. 
This allows filtering of CK Builder tests from the full CK repository test suite. diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp index 4682f636eb..375e465721 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_description.hpp @@ -98,7 +98,7 @@ struct ConvDescription f.writeLine(2, "Weights elementwise operation: ", signature.weight_element_op); f.writeLast(2, "Output elementwise operation: ", signature.output_element_op); - f.writeLine(1, "Algorithm"); + f.writeLast(1, "Algorithm"); // Compute Block section f.writeLine(2, "Thread block size: ", algorithm.thread_block_size); f.writeLine(2, @@ -123,7 +123,7 @@ struct ConvDescription algorithm.warp_gemm.n_iter); // Memory Access section - f.writeLine(2, "Memory access:"); + f.writeLast(2, "Memory access:"); f.writeLine(3, "A Tile transfer: "); f.writeLine(4, @@ -219,8 +219,6 @@ struct ConvDescription f.writeLast(4, "Vector access (GMEM write) instruction size: ", algorithm.c_tile_transfer.scalar_per_vector); - f.writeLast(2); - f.writeLast(1); return f.getString(); } diff --git a/experimental/builder/test/CMakeLists.txt b/experimental/builder/test/CMakeLists.txt index af83b430ea..e8326d0530 100644 --- a/experimental/builder/test/CMakeLists.txt +++ b/experimental/builder/test/CMakeLists.txt @@ -1,9 +1,48 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT +################################################################################ +# CK Builder Test Suite +################################################################################ +# +# This file defines the test suite for the Composable Kernel (CK) Builder, +# which is responsible for generating optimized GPU kernels for convolution +# operations. 
+# +# TESTING PHILOSOPHY: +# ------------------- +# Tests are organized into two main categories: +# +# 1. SMOKE TESTS (fast, < 1 second total) +# - Unit tests that verify the builder's internal logic +# - Do NOT compile GPU kernels (fast compilation) +# - Run these frequently during development for quick feedback +# - Target: `ninja smoke-builder` +# +# 2. REGRESSION TESTS (slower, may take minutes) +# - Integration tests that compile and verify actual GPU kernels +# - Ensure the builder generates valid, compilable code +# - Include expensive "factory tests" that build all MIOpen kernels +# - Run these before submitting changes +# - Target: `ninja regression-builder` +# +# QUICK START: +# ------------ +# - During development: ninja smoke-builder +# - Before submitting: ninja regression-builder +# - Run everything: ninja check-builder +# - Build specific test: ninja test_ckb_conv_builder && bin/test_ckb_conv_builder +# +################################################################################ + include(gtest) +################################################################################ +# Helper Functions +################################################################################ + # Helper function to create a gtest executable with common properties +# All builder tests share the same compilation settings and dependencies function(add_ck_builder_test test_name) add_executable(${test_name} ${ARGN} testing_utils.cpp) target_compile_features(${test_name} PRIVATE cxx_std_20) @@ -19,17 +58,51 @@ function(add_ck_builder_test test_name) target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock) endfunction() -# The test_ckb_conv_builder target has all the unit tests (each test should run < 10 ms) +# Factory tests attempt to build all the kernels needed by MIOpen. +# These are only for regression testing and development; the builds are too +# expensive for regular use in CI. 
+function(add_ck_factory_test test_name) + add_ck_builder_test(${test_name} ${ARGN}) + target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations) +endfunction() + +################################################################################ +# SMOKE TESTS - Fast Unit Tests (No Kernel Compilation) +################################################################################ +# These tests verify the builder's internal logic without compiling GPU kernels. +# They should complete in under 10ms each and are suitable for frequent execution +# during development. add_ck_builder_test(test_ckb_conv_builder test_conv_builder.cpp test_fwd_instance_traits.cpp test_bwd_weight_instance_traits.cpp test_bwd_data_instance_traits.cpp - test_instance_traits_util.cpp) + test_instance_traits_util.cpp +) + + # Tests the inline diff utility used for comparing strings in test assertions + add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) -add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp) + # Tests convolution trait selection and configuration + add_ck_builder_test(test_ckb_conv_traits + conv/test_conv_traits.cpp) + + # Tests convolution problem description and parameter handling + add_ck_builder_test(test_ckb_conv_description + test_conv_description.cpp) + +################################################################################ +# REGRESSION TESTS - Integration Tests (With Kernel Compilation) +################################################################################ +# These tests compile actual GPU kernels to verify the builder generates valid, +# compilable code. They are more expensive but catch real-world issues. -# Testing the virtual GetInstanceString methods requires kernel compilation. + +# Verifies that GetInstanceString() methods produce valid kernel code. 
+# Tests various convolution types: +# - Group convolution (v3, standard, large tensor, WMMA, DL variants) +# - Backward weight group convolution (XDL) +# Requires kernel compilation to validate the generated strings. add_ck_builder_test(test_ckb_get_instance_string test_get_instance_string_fwd_grp_conv_v3.cpp test_get_instance_string_fwd_grp_conv.cpp @@ -38,8 +111,8 @@ add_ck_builder_test(test_ckb_get_instance_string test_get_instance_string_fwd_grp_conv_dl.cpp test_get_instance_string_bwd_weight_grp_conv_xdl.cpp) -# Testing the fwd convolution builder requires kernel compilation. -# To enable parallel compilation, the individual tests are split into separate files. +# Tests the forward convolution builder across multiple data types and dimensions. +# Individual tests are split into separate files to enable parallel compilation. add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_1d_fp16.cpp conv/test_ckb_conv_fwd_1d_bf16.cpp @@ -55,15 +128,21 @@ add_ck_builder_test(test_ckb_build_fwd_instances conv/test_ckb_conv_fwd_3d_fp32.cpp ) -# Factory tests attempt to build all the kernels need by MIOpen. -# This is only for regression testing and development, the builds are too expensive for regular use in CI. -function(add_ck_factory_test test_name) - add_ck_builder_test(${test_name} ${ARGN}) - target_link_libraries(${test_name} PRIVATE composablekernels::device_conv_operations) -endfunction() -# TODO: add these tests back in once we have CI working across all GPU architectures. +################################################################################ +# FACTORY TESTS - Expensive Regression Tests (Full MIOpen Kernel Set) +################################################################################ +# These tests attempt to build ALL kernels needed by MIOpen for various +# convolution operations. They are extremely expensive (minutes to compile) +# and are intended for deep regression testing and development only. 
+# NOT suitable for regular CI runs. +# +# Many tests are commented out pending CI support across all GPU architectures. + +# Tests the testing utilities themselves add_ck_factory_test(test_ckb_testing_utils test_testing_utils.cpp) + +# TODO: Re-enable these tests once we have CI working across all GPU architectures. # add_ck_factory_test(test_ckb_factory_grouped_convolution_forward test_ck_factory_grouped_convolution_forward.cpp) # add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_clamp test_ck_factory_grouped_convolution_forward_clamp.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_convscale test_ck_factory_grouped_convolution_forward_convscale.cpp) @@ -75,22 +154,30 @@ add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_ab tes add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ck_factory_grouped_convolution_forward_scaleadd_scaleadd_relu.cpp) add_ck_factory_test(test_ckb_factory_grouped_convolution_forward_dynamic_op test_ck_factory_grouped_convolution_forward_dynamic_op.cpp) -add_ck_builder_test(test_ckb_conv_traits - conv/test_conv_traits.cpp) +################################################################################ +# CTest Integration - Register Tests and Assign Labels +################################################################################ +# Tests are registered with CTest and labeled for selective execution: +# - BUILDER_SMOKE: Fast unit tests for frequent development cycles +# - BUILDER_REGRESSION: Slower integration tests for pre-submission validation -add_ck_builder_test(test_ckb_conv_description - test_conv_description.cpp) - -# Register tests with CTest and assign labels include(CTest) -# Smoke test: fast-compiling unit test -add_test(NAME test_ckb_conv_builder COMMAND test_ckb_conv_builder) -set_tests_properties(test_ckb_conv_builder PROPERTIES LABELS "BUILDER_SMOKE") - -# Regression tests: all other tests that require kernel 
compilation -set(CKB_REGRESSION_TESTS +# Register all smoke tests (fast unit tests, no kernel compilation) +set(CKB_SMOKE_TESTS + test_ckb_conv_builder test_ckb_inline_diff + test_ckb_conv_traits + test_ckb_conv_description +) + +foreach(test_target ${CKB_SMOKE_TESTS}) + add_test(NAME ${test_target} COMMAND ${test_target}) + set_tests_properties(${test_target} PROPERTIES LABELS "BUILDER_SMOKE") +endforeach() + +# Register all regression tests (integration tests with kernel compilation) +set(CKB_REGRESSION_TESTS test_ckb_get_instance_string test_ckb_build_fwd_instances test_ckb_testing_utils @@ -98,8 +185,6 @@ set(CKB_REGRESSION_TESTS test_ckb_factory_grouped_convolution_forward_scaleadd_ab test_ckb_factory_grouped_convolution_forward_scaleadd_scaleadd_relu test_ckb_factory_grouped_convolution_forward_dynamic_op - test_ckb_conv_traits - test_ckb_conv_description ) foreach(test_target ${CKB_REGRESSION_TESTS}) @@ -107,18 +192,31 @@ foreach(test_target ${CKB_REGRESSION_TESTS}) set_tests_properties(${test_target} PROPERTIES LABELS "BUILDER_REGRESSION") endforeach() -# Helper target to build all regression tests +################################################################################ +# Custom Build Targets - Convenient Test Execution +################################################################################ +# These targets provide convenient ways to build and run different test suites: +# - smoke-builder: Quick sanity check during development +# - regression-builder: Thorough validation before submitting changes +# - check-builder: Complete test suite execution + +# Helper target to build all smoke tests (without running them) +add_custom_target(build-smoke-builder DEPENDS ${CKB_SMOKE_TESTS}) + +# Helper target to build all regression tests (without running them) add_custom_target(build-regression-builder DEPENDS ${CKB_REGRESSION_TESTS}) -# Target to run only smoke tests (builds only test_ckb_conv_builder) +# Target to run only smoke tests (builds and 
runs all smoke test executables) +# Use this for quick feedback during active development add_custom_target(smoke-builder COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "BUILDER_SMOKE" - DEPENDS test_ckb_conv_builder + DEPENDS build-smoke-builder USES_TERMINAL COMMENT "Running experimental builder smoke tests..." ) -# Target to run only regression tests (builds all regression test executables) +# Target to run only regression tests (builds and runs all regression test executables) +# Use this before submitting changes to catch integration issues add_custom_target(regression-builder COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -L "BUILDER_REGRESSION" DEPENDS build-regression-builder @@ -126,15 +224,20 @@ add_custom_target(regression-builder COMMENT "Running experimental builder regression tests..." ) -# Target to run all builder tests (builds all test executables) +# Target to run all builder tests (builds and runs all test executables) +# Use this for comprehensive validation add_custom_target(check-builder COMMAND ${CMAKE_CTEST_COMMAND} --output-on-failure -C ${CMAKE_CFG_INTDIR} -R "^test_ckb" - DEPENDS test_ckb_conv_builder build-regression-builder + DEPENDS build-smoke-builder build-regression-builder USES_TERMINAL COMMENT "Running all experimental builder tests..." 
) -# Print summary of test organization +################################################################################ +# Build Summary +################################################################################ + +# Print summary of test organization for developer reference message(STATUS "CK Builder test organization:") -message(STATUS " Smoke test: test_ckb_conv_builder") +message(STATUS " Smoke tests: ${CKB_SMOKE_TESTS}") message(STATUS " Regression tests: ${CKB_REGRESSION_TESTS}") diff --git a/experimental/builder/test/test_conv_description.cpp b/experimental/builder/test/test_conv_description.cpp index b83abe9f43..933995730a 100644 --- a/experimental/builder/test/test_conv_description.cpp +++ b/experimental/builder/test/test_conv_description.cpp @@ -127,41 +127,39 @@ TEST(ConvDescriptionTest, DefaultInstanceHasDetailedDescription) "│ ├─ Input elementwise operation: PASS_THROUGH\n" "│ ├─ Weights elementwise operation: PASS_THROUGH\n" "│ └─ Output elementwise operation: PASS_THROUGH\n" - "├─ Algorithm\n" - "│ ├─ Thread block size: 256\n" - "│ ├─ Data tile size: 256×256×32\n" - "│ ├─ Gemm padding: DEFAULT\n" - "│ ├─ Convolution specialization: DEFAULT\n" - "│ ├─ Pipeline version: V4\n" - "│ ├─ Pipeline scheduler: INTRAWAVE\n" - "│ ├─ Warp Gemm parameters: \n" - "│ │ ├─ subtile size: 16×16\n" - "│ │ └─ Number of warp gemm iterations: 4×4\n" - "│ ├─ Memory access:\n" - "│ │ ├─ A Tile transfer: \n" - "│ │ │ ├─ Tile dimensions: 4×256×8×\n" - "│ │ │ ├─ The innermost K subdimension size: 8\n" - "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" - "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" - "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" - "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" - "│ │ ├─ B Tile transfer: \n" - "│ │ │ ├─ Tile dimensions: 4×256×8×\n" - 
"│ │ │ ├─ The innermost K subdimension size: 8\n" - "│ │ │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" - "│ │ │ ├─ The order of accessing data tile axes: 0×1×2\n" - "│ │ │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" - "│ │ │ ├─ Vector access (GMEM read) instruction size: 8\n" - "│ │ │ ├─ Vector access (LDS write) instruction size: 8\n" - "│ │ │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" - "│ │ └─ C Tile transfer: \n" - "│ │ ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" - "│ │ ├─ Spatial thread distribution used to store data: 1×32×1×8\n" - "│ │ └─ Vector access (GMEM write) instruction size: 8\n" - "│ └─ \n" - "└─ ")); + "└─ Algorithm\n" + " ├─ Thread block size: 256\n" + " ├─ Data tile size: 256×256×32\n" + " ├─ Gemm padding: DEFAULT\n" + " ├─ Convolution specialization: DEFAULT\n" + " ├─ Pipeline version: V4\n" + " ├─ Pipeline scheduler: INTRAWAVE\n" + " ├─ Warp Gemm parameters: \n" + " │ ├─ subtile size: 16×16\n" + " │ └─ Number of warp gemm iterations: 4×4\n" + " └─ Memory access:\n" + " ├─ A Tile transfer: \n" + " │ ├─ Tile dimensions: 4×256×8×\n" + " │ ├─ The innermost K subdimension size: 8\n" + " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + " │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ ├─ Vector access (LDS write) instruction size: 8\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + " ├─ B Tile transfer: \n" + " │ ├─ Tile dimensions: 4×256×8×\n" + " │ ├─ The innermost K subdimension size: 8\n" + " │ ├─ Spatial thread distribution over the data tile: 0×1×2\n" + " │ ├─ The order of accessing data tile axes: 0×1×2\n" + " │ ├─ Vectorized memory access axis index (with contiguous memory): 2\n" + " │ ├─ Vector access (GMEM read) instruction size: 8\n" + " │ ├─ Vector access (LDS write) 
instruction size: 8\n" + " │ └─ LDS data layout padding (to prevent bank conflicts): 8\n" + " └─ C Tile transfer: \n" + " ├─ Data shuffle (number of gemm instructions per iteration): 1×1\n" + " ├─ Spatial thread distribution used to store data: 1×32×1×8\n" + " └─ Vector access (GMEM write) instruction size: 8")); } // NOTE: BackwardDataInstanceHasDetailedDescription test is disabled because ConvFactory diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index a42f7170aa..ec623db6f7 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -199,7 +199,7 @@ struct BaseArgument BaseArgument(const BaseArgument&) = default; BaseArgument& operator=(const BaseArgument&) = default; - virtual ~BaseArgument() {} + virtual __host__ __device__ ~BaseArgument() {} void* p_workspace_ = nullptr; }; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp new file mode 100644 index 0000000000..2f0c047167 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp @@ -0,0 +1,827 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_gemm_wmma_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + // Binary search lookup to find which group this block is part of + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= right) + { + if(block_id < gemm_desc_ptr[group_id].block_start_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + // NOTE: Local copy of the arg 
struct since SplitKBatchOffset verifies and modifies K index + // and thus needs a non-const reference. It's also not feasible to store this in global + // memory as different threads would be writing different K values to the same arg struct + auto karg = gemm_desc_ptr[group_id].karg_; + +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_; + + // Tile index first dimension is the K batch + auto tile_index = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + auto splitk_batch_offset = + typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run(static_cast(p_shared), + splitk_batch_offset, + karg, + block_2_ctile_map, + epilogue_args); +#if defined(__gfx11__) + } +#endif +#else + ignore = gemm_descs_const; + ignore = group_count; +#endif // end of if(defined(__gfx11__) || defined(__gfx12__)) +} + +template +struct DeviceGroupedGemm_Wmma_CShuffleV3 : public DeviceGroupedGemmSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(KPerBlock % AK1 == 0); + static constexpr index_t K0PerBlock = KPerBlock / AK1; + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + 
MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by DeviceGroupedGemmSplitK base class. + false>; // PermuteB not supported by DeviceGroupedGemmSplitK base class. + + using CGridDesc_M_N = + remove_cvref_t( + 1, 1, 1, 1, 1))>; + using Block2ETileMapKSplit = + BlockToCTileMap_KSplit_M00_N0_M01Adapt; + // Block2CTileMap configuration parameter.
+ static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; + using KernelArgument = typename GridwiseGemm::Argument; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; + template + struct GemmTransKernelArgBase + { + KernelArgument_ karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArgBase() = default; + GemmTransKernelArgBase(KernelArgument_&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } + }; + using GemmTransKernelArg = GemmTransKernelArgBase; + + static constexpr index_t DefaultKBatch = 1; + + static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) + { + index_t k_grain = karg.KBatch * KPerBlock; + index_t K_split = (karg.K + k_grain - 1) / karg.KBatch; + return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + } + + // Argument + // TODO: Add A/B/CDE element op? + struct Argument : public BaseArgument + { + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs) + : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector& p_Es, + std::vector& gemm_descs, + index_t kbatch) + : K_BATCH{kbatch}, gemm_kernel_host_args_{nullptr} + { + grid_size_ = 0; + group_count_ = ck::type_convert(gemm_descs.size()); + + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! 
group_count_ != p_As/b/c.size"); + } + + gemm_kernel_args_.reserve(group_count_); + + skipped_group_count_ = 0; + + for(std::size_t i = 0; i < gemm_descs.size(); ++i) + { + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(M == 0) + { + skipped_group_count_++; + continue; + } + + const index_t stride_a = gemm_descs[i].stride_A_; + const index_t stride_b = gemm_descs[i].stride_B_; + const index_t stride_c = gemm_descs[i].stride_C_; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + + const auto c_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + M, m_padded, N, n_padded, stride_c); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = KernelArgument(std::array{p_As[i]}, + std::array{p_Bs[i]}, + std::array{}, // p_ds_grid_ + type_convert(p_Es[i]), + M, + N, + K, + std::array{stride_a}, + std::array{stride_b}, + std::array{}, // StrideDs_ + stride_c, + K_BATCH, + PassThrough{}, + PassThrough{}, + PassThrough{}, + false); + + gemm_kernel_args_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + } + } + + /** + * @brief Recalculate group grid size for all gemms and update B2C maps. + * + * @param[in] kbatch The new splitK parameter value. 
+ */ + void UpdateKBatch(index_t kbatch) + { + K_BATCH = kbatch; + grid_size_ = 0; + + for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i) + { + auto& karg = gemm_kernel_args_[i].karg_; + + const index_t k_read = GridwiseGemm::CalculateKRead(karg.K, K_BATCH); + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH); + const index_t ak0_padded = GridwiseGemm::CalculateAK0Padded(karg.K, K_BATCH); + const index_t bk0_padded = GridwiseGemm::CalculateBK0Padded(karg.K, K_BATCH); + + const auto c_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideE); + + const auto local_b2c_tile_map = + Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + // block-to-e-tile map + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + karg.KRead = k_read; + karg.KPadded = k_padded; + karg.AK0 = ak0_padded; + karg.BK0 = bk0_padded; + karg.KBatch = K_BATCH; + gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_kernel_args_[i].block_start_ = block_start; + gemm_kernel_args_[i].block_end_ = block_end; + } + } + + // private: + index_t K_BATCH; + index_t group_count_; + index_t skipped_group_count_; + + std::vector gemm_kernel_args_; + void* gemm_kernel_host_args_; + index_t grid_size_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + float Run(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) + { + using GemmTransKernelArg_ = GemmTransKernelArgBase; + static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); + + bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.KBatch > 1; 
+ bool all_have_main_k0_block_loop = + CalculateHasMainKBlockLoop(arg.gemm_kernel_args_[0].karg_); + + bool not_all_have_main_k0_block_loop_same = false; + bool not_all_have_kbatch_value_same = false; + + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& karg = reinterpret_cast( + arg.gemm_kernel_args_[i].karg_); + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + auto kbatch = karg.KBatch; + + if(!GridwiseGemm::CheckValidity(karg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + not_all_have_main_k0_block_loop_same |= + all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(karg); + not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (kbatch > 1); + } + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! " << " in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + // If the user provides copy stream and copy event, we assume that they're also + // responsible for providing allocated host memory (eg. pinned) which + // would be used to copy kernel arguments to the device. + if(cpy_stream && cpy_event) + { + if(arg.gemm_kernel_host_args_ == nullptr) + { + std::ostringstream err; + err << "No memory has been allocated for gemm kernel host args " + << "when providing the copy stream and copy event! 
In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + hip_check_error(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_host_args_, + arg.group_count_ * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + cpy_stream)); + hip_check_error(hipEventRecord(cpy_event, cpy_stream)); + hip_check_error(hipEventSynchronize(cpy_event)); + } + else // In this case CK owns memory allocated on host. + { + + hip_check_error( + hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_args_.data(), + arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + } + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(all_have_kbatch_gt_one) + { + for(const auto& trans_arg : arg.gemm_kernel_args_) + { + const auto& karg = trans_arg.karg_; + hip_check_error(hipMemsetAsync(karg.p_e_grid, + 0, + karg.M * karg.N * sizeof(EDataType), + stream_config.stream_id_)); + } + } + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_kernel_args_.size()); + }; + + // NOTE: If at least one gemm problem has a main k0 block loop, we include it for all + if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) + { + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + else + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + } + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(all_have_kbatch_gt_one) + { + const auto kernel = + kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + else + { + const auto kernel = + 
kernel_grouped_gemm_wmma_splitk; + + Run(kernel); + } + } + } + + return ave_time; + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static constexpr bool IsValidCompilationParameter() + { + // TODO: properly implement this check + return true; + } + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.K_BATCH > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((ck::type_convert(arg.gemm_kernel_args_.size()) + + arg.skipped_group_count_) != arg.group_count_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i) + { + const auto& a = arg.gemm_kernel_args_[i].karg_; + bool group_arg_valid = GridwiseGemm::CheckValidity(a); + + if(not group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + a.Print(); + } + } + supported = supported && group_arg_valid; + } + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) + { + return Argument{p_As, p_Bs, p_Es, gemm_descs}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>&, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation) override + { + return std::make_unique(p_As, p_Bs, p_Es, gemm_descs); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedGemm_WmmaSplitK" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector 
<< ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + return p_arg_->gemm_kernel_args_.size() * sizeof(GemmTransKernelArg); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + return GetWorkSpaceSize(p_arg); + } + + size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } + + // TODO: deprecation notice. + static void SetKBatchSize(Argument& arg, index_t kbatch) { arg.UpdateKBatch(kbatch); } + + // polymorphic + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->UpdateKBatch(kbatch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const override + { + return this->SetWorkSpacePointer(p_arg, p_dev_kernel_args); + } + + //---------------------------------------------------------------------------------------------- + /// @brief Sets the host kernel arguments pointer and copies that data on the host side. + /// This function can be utilised to use pinned memory for the host args and + /// achieve fully async data copy. + /// + /// @param p_arg The pointer to the Argument we're going to update.
+ /// @param[in] p_host_kernel_args The pointer to the host memory where the kernel + /// arguments will be copied + /// + void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const + { + Argument* pArg_ = dynamic_cast(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_kernel_args_.begin(), + pArg_->gemm_kernel_args_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 2f3555e33e..6629be2511 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -470,9 +470,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 DsGridPointer p_ds_grid; EDataType* p_e_grid; - const AElementwiseOperation a_element_op; - const BElementwiseOperation b_element_op; - const CDEElementwiseOperation cde_element_op; + AElementwiseOperation a_element_op; + BElementwiseOperation b_element_op; + CDEElementwiseOperation cde_element_op; // TODO: it can be used with SplitK+reduction but currently only used with SplitK+atomicAdd bool is_reduce; @@ -555,13 +555,17 @@ struct GridwiseGemm_wmma_cshuffle_v3 template + typename Block2CTileMap, + typename EpilogueArgument, + int BlockMapMBlockIndex = 0, + int BlockMapNBlockIndex = 1> __device__ static void Run(AsGridPointer& p_as_grid, BsGridPointer& p_bs_grid, DsGridPointer& p_ds_grid, EDataType* p_e_grid, void* p_shared, const Problem& problem, + const Block2CTileMap& block_2_ctile_map, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, @@ -582,9 +586,6 @@ struct GridwiseGemm_wmma_cshuffle_v3 
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, problem.MBlock, problem.NBlock); - // divide block work by [M, N] - const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); @@ -596,8 +597,10 @@ struct GridwiseGemm_wmma_cshuffle_v3 return; } - const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); - const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + const index_t block_m_id = + __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); + const index_t block_n_id = + __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); // BScale struct (Empty) using BScale = typename BlockwiseGemmPipe::Empty; @@ -632,15 +635,51 @@ struct GridwiseGemm_wmma_cshuffle_v3 epilogue_args); } + template + __device__ static void Run(AsGridPointer& p_as_grid, + BsGridPointer& p_bs_grid, + DsGridPointer& p_ds_grid, + EDataType* p_e_grid, + void* p_shared, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + EpilogueArgument& epilogue_args) + { + Run(p_as_grid, + p_bs_grid, + p_ds_grid, + p_e_grid, + p_shared, + problem, + DefaultBlock2CTileMap(problem), + a_element_op, + b_element_op, + cde_element_op, + epilogue_args); + } + // Wrapper function to have __global__ function in common // between gemm_universal, b_scale, ab_scale, etc. 
template + typename Block2CTileMap, + typename EpilogueArgument, + int BlockMapMBlockIndex = 0, + int BlockMapNBlockIndex = 1> __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg, + const Block2CTileMap& block_2_ctile_map, EpilogueArgument& epilogue_args) { // shift A matrices pointer for splitk @@ -659,17 +698,47 @@ struct GridwiseGemm_wmma_cshuffle_v3 splitk_batch_offset.b_k_split_offset[i]; }); - Run( - p_as_grid_splitk, - p_bs_grid_splitk, - karg.p_ds_grid, - karg.p_e_grid + splitk_batch_offset.c_reduce_offset, - p_shared, - karg, - karg.a_element_op, - karg.b_element_op, - karg.cde_element_op, - epilogue_args); + Run(p_as_grid_splitk, + p_bs_grid_splitk, + karg.p_ds_grid, + karg.p_e_grid + splitk_batch_offset.c_reduce_offset, + p_shared, + karg, + block_2_ctile_map, + karg.a_element_op, + karg.b_element_op, + karg.cde_element_op, + epilogue_args); + } + + // Wrapper function to have __global__ function in common + // between gemm_universal, b_scale, ab_scale, etc. 
+ template + __device__ static void Run(void* p_shared, + const SplitKBatchOffset& splitk_batch_offset, + Argument& karg, + EpilogueArgument& epilogue_args) + { + Run( + p_shared, splitk_batch_offset, karg, DefaultBlock2CTileMap(karg), epilogue_args); + } + + __device__ static auto DefaultBlock2CTileMap(const Problem& problem) + { + return Block2CTileMap{problem.M, problem.N, 4}; } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 80103cfdb3..4de3a35b3e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -729,6 +729,13 @@ struct GridwiseGemm_wmma_cshuffle_v3_base auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec; if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K) { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Arg K value too low for combination of AK1/BK1/KBatch. 
AK1: " + << AK1Number << ", BK1: " << BK1Number << ", KBatch: " << karg.KBatch + << ", K: " << karg.K << " " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } return false; } } diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index 91c0779ab6..e80267faec 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -293,6 +293,15 @@ struct tile_window_with_static_distribution 0, dst_tensor, number{}, bool_constant{}); } + template + CK_TILE_DEVICE constexpr auto get_load_offset(offset_t = {}) const + { + constexpr auto bottom_tensor_idx_off = to_multi_index(offset_t{}); + const auto bottom_tensor_coord_off = make_tensor_coordinate( + this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off); + return amd_wave_read_first_lane(bottom_tensor_coord_off.get_offset()); + } + template ) return offset_t::value; else - { - auto bottom_tensor_idx_off = to_multi_index(offset_t{}); - auto bottom_tensor_coord_off = make_tensor_coordinate( - this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off); - return bottom_tensor_coord_off.get_offset(); - } + return get_load_offset(offset_t{}); }(); // loop over thread tensor space [y0, y1, ...] 
static_for<0, NumCoord, 1>{}([&](auto iCoord) { diff --git a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp index e5c666de46..ff799cb0fc 100644 --- a/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp @@ -46,8 +46,8 @@ struct MXFlatmmPipelineProblem : FlatmmPipelineProblem -struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 +template +struct MXFlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1 { using Underlying = FlatmmPipelineAGmemBGmemCRegV1; @@ -470,17 +470,39 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(); } + template + CK_TILE_DEVICE auto operator()(Args&&... args) const + { + auto c_warp_tensors = Run_(std::forward(args)...); + + // Block GEMM Acc register tile + using CWarpDstr = typename WG::CWarpDstr; + constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + auto c_block_tile = BlockFlatmm{}.MakeCBlockTile(); + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + c_block_tile.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensors(mIter)(nIter).get_thread_buffer()); + }); + }); + return c_block_tile; + } + template - CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_copy_dram_window_tmp, - const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, - const ScaleADramBlockWindowTmp& scale_a_window, - const ScaleBDramBlockWindowTmp& scale_b_window, - index_t num_loop, - void* __restrict__ p_smem_ping, - void* __restrict__ p_smem_pong) const + CK_TILE_DEVICE auto Run_(const ADramBlockWindowTmp& 
a_copy_dram_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + const ScaleADramBlockWindowTmp& scale_a_window, + const ScaleBDramBlockWindowTmp& scale_b_window, + index_t num_loop, + void* __restrict__ p_smem_ping, + void* __restrict__ p_smem_pong) const { #ifndef __gfx950__ static_assert(false, "Only gfx950 is supported for MXFP4 flatmm pipeline now."); @@ -497,19 +519,14 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}; - auto a_dram_window = - make_tile_window(PipelinePolicy::template MakeMXFP4_AAsyncLoadDramDescriptor( + make_tile_window(PipelinePolicy::template MakeMX_AAsyncLoadDramDescriptor( a_copy_dram_window_tmp.get_bottom_tensor_view()), a_copy_dram_window_tmp.get_window_lengths(), a_copy_dram_window_tmp.get_window_origin(), - PipelinePolicy::template MakeMXFP4_ADramTileDistribution()); + PipelinePolicy::template MakeMX_ADramTileDistribution()); __builtin_amdgcn_sched_barrier(0); @@ -518,7 +535,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1(p_smem_pong); constexpr auto a_lds_block_desc = - PipelinePolicy::template MakeMXFP4_ALdsBlockDescriptor(); + PipelinePolicy::template MakeMX_ALdsBlockDescriptor(); auto a_lds_block_ping = make_tensor_view(p_a_lds_ping, a_lds_block_desc); @@ -535,39 +552,34 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number{}), {0, 0}, - PipelinePolicy::template MakeMXF4_ALDS_TileDistribution()); + PipelinePolicy::template MakeMX_ALDS_TileDistribution()); auto a_warp_window_pong = make_tile_window(a_lds_block_pong, make_tuple(number{}, number{}), {0, 0}, - PipelinePolicy::template MakeMXF4_ALDS_TileDistribution()); - - // Block GEMM - auto block_flatmm = BlockFlatmm(); - // Acc register tile - auto c_block_tile = block_flatmm.MakeCBlockTile(); + PipelinePolicy::template MakeMX_ALDS_TileDistribution()); // B flat DRAM window for load // pingpong buffer for B - auto b_flat_dram_windows = generate_tuple( + auto 
b_flat_dram_window = + make_tile_window(b_flat_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + PipelinePolicy::template MakeMX_BFlatDramTileDistribution()); + auto b_flat_dram_offsets = generate_tuple( [&](auto nIter) { constexpr auto packed_n_idx = nIter / number{}; constexpr auto packed_n_rank = nIter % number{}; - auto window_i = make_tile_window( - b_flat_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), - b_flat_dram_block_window_tmp.get_window_origin(), - PipelinePolicy::template MakeMXFP4_BFlatDramTileDistribution()); - move_tile_window( - window_i, - {number{}, - number<0>{}}); - return window_i; + return b_flat_dram_window.get_load_offset( + tuple, + number<0>>{}) + + b_flat_dram_window.get_load_offset( + tuple, number<0>>{}); }, number{}); statically_indexed_array< - statically_indexed_array, + statically_indexed_array, NIterPerWarp> b_warp_tensor_ping, b_warp_tensor_pong; @@ -576,41 +588,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}, number<64 / WG::kM>{}), scale_a_window.get_window_origin(), - PipelinePolicy::template MakeMXFP4_ScaleA_FlatDramTileDistribution()); + PipelinePolicy::template MakeMX_ScaleA_FlatDramTileDistribution()); + const auto scale_a_dram_step_m = amd_wave_read_first_lane( + scale_a_dram_window.get_load_offset(tuple, number<0>>{})); + const auto scale_a_dram_step_k = amd_wave_read_first_lane( + scale_a_dram_window.get_load_offset(tuple, number<64 / WG::kM>>{})); auto scale_b_dram_window = make_tile_window( scale_b_window.get_bottom_tensor_view(), make_tuple(number{}, number<64 / WG::kN>{}), scale_b_window.get_window_origin(), - PipelinePolicy::template MakeMXFP4_ScaleB_DramTileDistribution()); + PipelinePolicy::template MakeMX_ScaleB_DramTileDistribution()); + const auto scale_b_dram_step_n = amd_wave_read_first_lane( + scale_b_dram_window.get_load_offset(tuple, number<0>>{})); + 
const auto scale_b_dram_step_k = amd_wave_read_first_lane( + scale_b_dram_window.get_load_offset(tuple, number<64 / WG::kN>>{})); + + constexpr index_t MPackIterPerWarp = MIterPerWarp / MXdlPack; + constexpr index_t NPackIterPerWarp = NIterPerWarp / NXdlPack; + constexpr index_t KPackIterPerWarp = KIterPerWarp / KXdlPack; // ping pong buffer for scale A statically_indexed_array< - statically_indexed_array, - MIterPerWarp / MXdlPack> - scale_a_dram_windows; - statically_indexed_array, - MIterPerWarp / MXdlPack> - scale_a_tile_tensor_ping; - statically_indexed_array, - MIterPerWarp / MXdlPack> - scale_a_tile_tensor_pong; + statically_indexed_array, + MPackIterPerWarp> + scale_a_tile_tensor_ping, scale_a_tile_tensor_pong; // ping pong buffer for scale B statically_indexed_array< - statically_indexed_array, - NIterPerWarp / NXdlPack> - scale_b_dram_windows; - statically_indexed_array, - NIterPerWarp / NXdlPack> - scale_b_tile_tensor_ping; - statically_indexed_array, - NIterPerWarp / NXdlPack> - scale_b_tile_tensor_pong; + statically_indexed_array, + NPackIterPerWarp> + scale_b_tile_tensor_ping, scale_b_tile_tensor_pong; auto async_load_tile_ = [](auto lds, auto dram) { async_load_tile(lds, dram, number<-1>{}, true_type{}, false_type{}); @@ -625,35 +633,31 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto nIter) { static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); + b_flat_dram_window, b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); }); // move B window to next flat K - move_tile_window(b_flat_dram_windows(nIter), {0, KIterPerWarp * KFlatPerBlockPerIter}); + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); // prefetch Scale A - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - 
scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset( + scale_a_dram_window, - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k); }); }); // move Scale A window to next K move_tile_window(scale_a_dram_window, {0, kKPerBlock / (32 * KXdlPack)}); // prefetch Scale B - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = load_tile_with_offset( + scale_b_dram_window, + nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k); }); }); // move Scale B window to next K @@ -667,7 +671,12 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1, MIterPerWarp> + c_warp_tensors; + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, NIterPerWarp, 1>{}( + [&](auto nIter) { clear_tile(c_warp_tensors(mIter)(nIter)); }); + }); statically_indexed_array a_warp_tensor; @@ -688,40 +697,37 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { static_for<0, NIterPerWarp, 
1>{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + + // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) - move_tile_window(b_flat_dram_windows(nIter), - {0, BlockGemmShape::flatKPerBlock}); + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset( + scale_a_dram_window, + mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k); }); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset( + scale_b_dram_window, + nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k); 
}); }); // GEMM 2i - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; @@ -729,39 +735,22 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto inxdl) { constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; - - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_ping(number{})(number{}), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds constexpr auto addr = m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -802,81 +791,60 @@ struct 
MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { b_warp_tensor_ping(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), number{}); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); + + // move B window to next flat K if constexpr(kIter == KIterPerWarp - 1) - move_tile_window(b_flat_dram_windows(nIter), - {0, BlockGemmShape::flatKPerBlock}); + b_flat_dram_offsets(nIter) += b_flat_dram_window.get_load_offset( + tuple, number>{}); }); }); // prefetch Scale A and Scale B (2i+2) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - - scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) = load_tile_with_offset( + scale_a_dram_window, + mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k); }); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - - scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { + scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) = 
load_tile_with_offset( + scale_b_dram_window, + nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k); }); }); // GEMM 2i+1 - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_pong(number{})(number{}), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); // scale B - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - 
m_preload; + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -928,78 +896,54 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter) { static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { b_warp_tensor_pong(nIter)(kIter) = load_tile_with_offset( - b_flat_dram_windows(nIter), - make_tuple(number<0>{}, number{})); + b_flat_dram_window, + b_flat_dram_offsets(nIter) + kIter * KFlatPerBlockPerIter); }); }); // prefetch Scale A and Scale B (2i+1) - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_a_dram_windows(mIter_pack)(kIter_pack) = scale_a_dram_window; - move_tile_window(scale_a_dram_windows(mIter_pack)(kIter_pack), - {mIter_pack * MWarp * WG::kM, kIter_pack * (64 / WG::kM)}); - - scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = - load_tile(scale_a_dram_windows(mIter_pack)(kIter_pack)); + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) = load_tile_with_offset( + scale_a_dram_window, + mIter_pack * scale_a_dram_step_m + kIter_pack * scale_a_dram_step_k); }); }); - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - scale_b_dram_windows(nIter_pack)(kIter_pack) = scale_b_dram_window; - move_tile_window(scale_b_dram_windows(nIter_pack)(kIter_pack), - {nIter_pack * NWarp * WG::kN, kIter_pack * (64 / WG::kN)}); - - scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = - load_tile(scale_b_dram_windows(nIter_pack)(kIter_pack)); + static_for<0, NPackIterPerWarp, 1>{}([&](auto 
nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) = load_tile_with_offset( + scale_b_dram_window, + nIter_pack * scale_b_dram_step_n + kIter_pack * scale_b_dram_step_k); }); }); // GEMM loopK-1 - static_for<0, KIterPerWarp / KXdlPack, 1>{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_ping(number{})(number{}), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); // scale B - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds - constexpr auto 
addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -1028,50 +972,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_pong(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + b_warp_tensor_pong(number{})(number{}), scale_a_tile_tensor_pong(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_pong(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); // scale B - - // write C warp tensor into C block tensor - 
c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -1089,50 +1015,32 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}([&](auto kIter_pack) { - static_for<0, MIterPerWarp / MXdlPack, 1>{}([&](auto mIter_pack) { - static_for<0, NIterPerWarp / NXdlPack, 1>{}([&](auto nIter_pack) { + static_for<0, KPackIterPerWarp, 1>{}([&](auto kIter_pack) { + static_for<0, MPackIterPerWarp, 1>{}([&](auto mIter_pack) { + static_for<0, NPackIterPerWarp, 1>{}([&](auto nIter_pack) { static_for<0, KXdlPack, 1>{}([&](auto ikxdl) { static_for<0, MXdlPack, 1>{}([&](auto imxdl) { constexpr auto AwarpIter = imxdl + ikxdl * MXdlPack; + constexpr auto m_iter = mIter_pack * MXdlPack + imxdl; + constexpr auto k_iter = kIter_pack * KXdlPack + ikxdl; static_for<0, NXdlPack, 1>{}([&](auto inxdl) { - // read C warp tensor from C block tensor - CWarpTensor c_warp_tensor; - c_warp_tensor.get_thread_buffer() = - c_block_tile.get_y_sliced_thread_data( - merge_sequences( - sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); - + constexpr auto n_iter = nIter_pack * NXdlPack + inxdl; // warp GEMM WG{}.template operator()( - c_warp_tensor, + c_warp_tensors(number{})(number{}), a_warp_tensor(number{}), - b_warp_tensor_ping(nIter_pack * number{} + inxdl)( - kIter_pack * number{} + ikxdl), + 
b_warp_tensor_ping(number{})(number{}), scale_a_tile_tensor_ping(mIter_pack)(kIter_pack) .get_thread_buffer()[0], // scale A scale_b_tile_tensor_ping(nIter_pack)(kIter_pack) .get_thread_buffer()[0]); // scale B - - // write C warp tensor into C block tensor - c_block_tile.set_y_sliced_thread_data( - merge_sequences(sequence{}, - c_warp_y_index_zeros), - merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), - c_warp_tensor.get_thread_buffer()); }); // preload next A from lds - constexpr auto addr = (mIter_pack * MXdlPack + imxdl) % 2 + - (kIter_pack * KXdlPack + ikxdl) * 2 + - (mIter_pack * MXdlPack + imxdl) / 2 * 4 + - m_preload; + constexpr auto addr = + m_iter % 2 + k_iter * 2 + m_iter / 2 * 4 + m_preload; if constexpr(addr < (KIterPerWarp * MIterPerWarp) && - (nIter_pack == NIterPerWarp / NXdlPack - 1)) + (nIter_pack == NPackIterPerWarp - 1)) { constexpr auto AmIter = addr % 2 + addr / 4 * 2; constexpr auto AkIter = addr / 2 % 2; @@ -1151,7 +1059,7 @@ struct MXF4FlatmmPipelineAGmemBGmemCRegV1 : FlatmmPipelineAGmemBGmemCRegV1{}; static constexpr auto I1 = number<1>{}; @@ -58,7 +58,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy template CK_TILE_DEVICE static constexpr auto - MakeMXFP4_AAsyncLoadDramDescriptor(const TensorView& naive_view) + MakeMX_AAsyncLoadDramDescriptor(const TensorView& naive_view) { using ADataType = remove_cvref_t; using ALayout = remove_cvref_t; @@ -107,7 +107,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_DEVICE static constexpr auto MakeMXFP4_ADramTileDistribution() + CK_TILE_DEVICE static constexpr auto MakeMX_ADramTileDistribution() { using ADataType = remove_cvref_t; @@ -140,7 +140,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_DEVICE static constexpr auto MakeMXFP4_ALdsBlockDescriptor() + CK_TILE_DEVICE static constexpr auto MakeMX_ALdsBlockDescriptor() { using ADataType = remove_cvref_t; using 
ALayout = remove_cvref_t; @@ -218,7 +218,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXF4_ALDS_TileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ALDS_TileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -255,7 +255,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_BFlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_BFlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -298,7 +298,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_DramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_DramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape @@ -335,7 +335,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_DramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_DramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; // ck_tile::TileFlatmmShape @@ -372,7 +372,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleA_FlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleA_FlatDramTileDistribution() { using TileShape = typename Problem::BlockGemmShape; @@ -394,7 +394,7 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy } template - CK_TILE_HOST_DEVICE static constexpr auto MakeMXFP4_ScaleB_FlatDramTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto MakeMX_ScaleB_FlatDramTileDistribution() { using 
TileShape = typename Problem::BlockGemmShape; @@ -420,8 +420,8 @@ struct MXF4FlatmmPipelineAgBgCrPolicy : UniversalFlatmmPipelineAgBgCrPolicy { using ADataType = remove_cvref_t; constexpr index_t APackedSize = numeric_traits::PackedSize; - return sizeof(ADataType) * - MakeMXFP4_ALdsBlockDescriptor().get_element_space_size() / APackedSize; + return sizeof(ADataType) * MakeMX_ALdsBlockDescriptor().get_element_space_size() / + APackedSize; } template diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp index f22d2f599a..d448cdbb93 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp @@ -23,7 +23,8 @@ struct BaseGemmPipelineAgBgCrCompV4 CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop) { - return num_loop > PrefetchStages; + constexpr index_t HotLoopGlobalReads = 2; + return num_loop >= (HotLoopGlobalReads + PrefetchStages); } CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 06fc4d5bfa..ad4d0baab2 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -373,8 +373,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { // Need to multiply aquant with accumulated C // - // The accumulated C tile has the standard distribution. For example - // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], + // The accumulated C tile has the standard distribution. 
For example, a + // 32x32 C lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0], // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0], // [26,0], [27,0]. // @@ -388,35 +388,31 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase // // These scales can be obtained using __builtin_amdgcn_ds_bpermute. - // MIters per warp - constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM; - // Reg block offset based on mIter - constexpr index_t reg_block_offset = - ((mIter / mIters_per_warp) * Traits::AQPerBlock); - - constexpr index_t lane_base_offset = - (mIter % mIters_per_warp) * WarpGemm::kM; - - // Scale tensor offset along K - constexpr index_t src_reg_offset = reg_block_offset + kQScale; - // Directly index into thread buffer corresponding to - // desired row coefficient + // Each thread stores AQPerBlock scale values per M iteration. + constexpr index_t reg_block_offset = mIter * Traits::AQPerBlock; + constexpr index_t src_reg_offset = reg_block_offset + kQScale; auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset]; - constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8; - ; - constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows; - constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane; - // Multiply by 4 because output is stored in tiles of 4 - // x CNLane - constexpr uint32_t row_base = - ((reg_offset_for_row_data / kTiledCMsPerWarp) * kTiledCMsPerWarp) + - ((reg_offset_for_row_data % kTiledCMsPerWarp) / WarpGemm::kCMLane); + // Divide M dimension of C Warp tile into groups of + // (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) + // m_base_offset_of_c_row indicates which group the current c_row belongs + // to. 
+ constexpr index_t m_base_offset_of_c_row = + (c_row / WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane) * + (WarpGemm::kCMLane * WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane); + + // M offset of each thread within its group (see comment above) + index_t m_base_offset_of_lane = + (get_lane_id() / WarpGemm::kN * + WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane); + + // M offset wrt. c_row in the subgroup of kCM1PerLane + constexpr index_t m_offset_of_c_row = + c_row & (WarpGemm::WarpGemmAttribute::Impl::kCM1PerLane - 1); - // Lane index to source scale from uint32_t src_lane_idx = - lane_base_offset + row_base + (__lane_id() / WarpGemm::kN * kTileRows); + m_base_offset_of_c_row + m_base_offset_of_lane + m_offset_of_c_row; return exchange_quant_value_across_lanes(scale_reg, src_lane_idx); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp index dae099af4f..f2142b4fdf 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp @@ -94,21 +94,20 @@ struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding // # of elements per thread constexpr index_t X = XPerTile; - constexpr index_t Y0 = 1; - constexpr index_t Y1 = MIterPerWarp ? MIterPerWarp : 1; - constexpr index_t Y2 = MWarps; - constexpr index_t Y3 = WarpGemm::kM; - static_assert(Y3 >= WarpGemm::kM, + constexpr index_t YR = 1; + constexpr index_t Y0 = MIterPerWarp ? 
MIterPerWarp : 1; + constexpr index_t Y1 = MWarps; + constexpr index_t Y2 = WarpGemm::kM; + static_assert(Y2 >= WarpGemm::kM, "Scales for all rows must be available within the warp."); - static_assert(Y0 * Y1 * Y2 * Y3 == YPerTile, - "Y0, Y1, Y2, Y3 must cover the blocktile along Y."); + static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 1>>, - tuple, sequence<0, 3>>, + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, sequence<1, 2>, - sequence<1, 0>>{}); + sequence<0, 0>>{}); } } }; diff --git a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp index 71c3dc4cdf..deb4dcb3db 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp @@ -8,7 +8,7 @@ namespace ck_tile { template 1 - return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_), - make_tuple(NStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(N_ * Wo_, K_), make_tuple(WoStride, KStride), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Wo_, K_), + make_tuple(NStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> CK_TILE_HOST auto make_wei_grid_desc() const { // GKXC - return make_naive_tensor_descriptor( - make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number{}, I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(K_, C_), make_tuple(C_, I1), number{}, I1); + } + else + { + return 
make_naive_tensor_descriptor( + make_tuple(K_, X_, C_), make_tuple(X_ * C_, C_, I1), number{}, I1); + } } template ::type = false> @@ -491,14 +507,22 @@ struct TransformConvBwdDataToGemm { // NWGC const index_t NStride = Wi_ * G_ * C_; - const index_t WiStride = G_ * C_; // GC? + const index_t WiStride = G_ * C_; constexpr auto CStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), - make_tuple(NStride, WiStride, CStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(N_ * Wi_, C_), make_tuple(WiStride, CStride), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Wi_, C_), + make_tuple(NStride, WiStride, CStride), + number{}, + I1); + } } template ::type = false> @@ -512,10 +536,20 @@ struct TransformConvBwdDataToGemm // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_), - make_tuple(NStride, HoStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Ho_ * Wo_, K_), + make_tuple(WoStride, KStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, K_), + make_tuple(NStride, HoStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> @@ -528,20 +562,38 @@ struct TransformConvBwdDataToGemm constexpr auto CStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), - make_tuple(NStride, HiStride, WiStride, CStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Hi_ * Wi_, C_), + make_tuple(WiStride, CStride), + number{}, + I1); + } + else + { + return 
make_naive_tensor_descriptor(make_tuple(N_, Hi_, Wi_, C_), + make_tuple(NStride, HiStride, WiStride, CStride), + number{}, + I1); + } } template ::type = false> CK_TILE_HOST auto make_wei_grid_desc() const { // GKYXC - return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_), - make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor( + make_tuple(K_, C_), make_tuple(C_, I1), number{}, I1); + } + else + { + return make_naive_tensor_descriptor(make_tuple(K_, Y_, X_, C_), + make_tuple(C_ * X_ * Y_, C_ * X_, C_, I1), + number{}, + I1); + } } template ::type = false> @@ -555,11 +607,21 @@ struct TransformConvBwdDataToGemm constexpr auto KStride = I1; // TODO Add support for NumGroupsToMerge > 1 - return make_naive_tensor_descriptor( - make_tuple(N_, Do_, Ho_, Wo_, K_), - make_tuple(NStride, DoStride, HoStride, WoStride, KStride), - number{}, - I1); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, K_), + make_tuple(WoStride, KStride), + number{}, + I1); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(N_, Do_, Ho_, Wo_, K_), + make_tuple(NStride, DoStride, HoStride, WoStride, KStride), + number{}, + I1); + } } template ::type = false> @@ -612,103 +674,111 @@ struct TransformConvBwdDataToGemm const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_wop_k_grid_desc = - transform_tensor_descriptor(out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return 
make_tuple(out_grid_desc, wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wo_, I0, I0), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( - out_n_xdot_wtilde_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + const auto out_n_xdotslice_wtildeslice_k_grid_desc = transform_tensor_descriptor( + out_n_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + 
make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), - make_merge_transform(make_tuple(N_, WTildeSlice))), - make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_merge_transform(make_tuple(N_, WTildeSlice))), + make_tuple(sequence<1, 3>{}, sequence<0, 2>{}), + make_tuple(sequence<1>{}, sequence<0>{})); - // B: weight tensor comes in K_N - const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + // B: weight tensor comes in K_N + const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor( - wei_k_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - 
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<>{}, sequence<2>{})); + const auto wei_k_xdotslice_c_grid_desc = transform_tensor_descriptor( + wei_k_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<>{}, sequence<2>{})); - const auto wei_gemmn_gemmkraw_grid_desc = - transform_tensor_descriptor(wei_k_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // c: input - const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + // c: input + const auto in_n_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); - const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - 
make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); + const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{})); - const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<>{}, sequence<1>{}, sequence<2>{})); + const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<>{}, sequence<1>{}, sequence<2>{})); - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1>{}, sequence<2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, 
WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } } template ::type = false> @@ -734,39 +804,135 @@ struct TransformConvBwdDataToGemm const auto XDotSlice = integer_divide_ceil(X_ - IdxXTilde_, XTilde_); const auto out_grid_desc = make_out_grid_desc(); - const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); + const auto in_grid_desc = make_in_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_tuple(out_grid_desc, 
wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{})); + + const auto out_gemmm_gemmkraw_grid_desc = 
transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<2>{}, + sequence<4>{}, + sequence<5>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<>{}, + sequence<>{}, + sequence<3>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 0>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + 
make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxYTilde_), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, @@ -774,111 +940,23 @@ struct TransformConvBwdDataToGemm sequence<4>{}, sequence<5>{}), make_tuple(sequence<0>{}, + sequence<>{}, sequence<1>{}, + sequence<>{}, sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{})); + sequence<3>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice))), - make_tuple(sequence<1, 3, 5>{}, sequence<0, 2, 4>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = 
transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // B: weight tensor comes in K_N - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto wei_k_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxYTilde_), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<3>{}, - sequence<2>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<>{}, - sequence<>{}, - sequence<3>{})); - - const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( - wei_k_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 2, 0>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - // c: input - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - 
make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(YTilde_, HTilde_), - make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<0>{}, - sequence<>{}, - sequence<1>{}, - sequence<>{}, - sequence<2>{}, - sequence<3>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + 
in_gemmmraw_gemmnraw_grid_desc); + } } template ::type = false> @@ -914,45 +992,174 @@ struct TransformConvBwdDataToGemm const auto in_grid_desc = make_in_grid_desc(); const auto wei_grid_desc = make_wei_grid_desc(); - // A: output tensor comes in K_M - const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( - out_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Do_, I0, I0), - make_pad_transform(Ho_, I0, I0), - make_pad_transform(Wo_, I0, I0), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); - - const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( - out_n_hop_wop_k_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(ZDot_, DTilde_), - make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), - make_embed_transform(make_tuple(YDot_, HTilde_), - make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, WTilde_), - make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(K_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = - transform_tensor_descriptor( - out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + if constexpr(ConvSpec == ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + return make_tuple(out_grid_desc, wei_grid_desc, in_grid_desc); + } + else + { + // A: output tensor comes in K_M + const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor( + out_grid_desc, make_tuple(make_pass_through_transform(N_), - make_slice_transform(ZDot_, I0, ZDotSlice), - make_slice_transform(DTilde_, 
IDTildeSliceBegin, DTildeSlice), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pad_transform(Do_, I0, I0), + make_pad_transform(Ho_, I0, I0), + make_pad_transform(Wo_, I0, I0), make_pass_through_transform(K_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor( + out_n_hop_wop_k_grid_desc, + make_tuple( + make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZDot_, DTilde_), + make_tuple(-ConvDilationD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, HTilde_), + make_tuple(-ConvDilationH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, WTilde_), + make_tuple(-ConvDilationW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(K_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc = + transform_tensor_descriptor( + out_n_ydot_htilde_xdot_wtilde_k_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(K_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + 
sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<2>{}, + sequence<3>{}, + sequence<4>{}, + sequence<5>{}, + sequence<6>{}, + sequence<7>{})); + + const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( + out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, + make_tuple( + make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), + make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + // B: weight tensor comes in K_N + const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( + wei_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_embed_transform(make_tuple(ZDot_, ZTilde_), + make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), + make_embed_transform(make_tuple(YDot_, YTilde_), + make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), + make_embed_transform(make_tuple(XDot_, XTilde_), + make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto wei_k_ydotslice_xdotslice_c_grid_desc = + transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, + make_tuple(make_pass_through_transform(K_), + make_slice_transform(ZDot_, I0, ZDotSlice), + make_slice_transform(YDot_, I0, YDotSlice), + make_slice_transform(XDot_, I0, XDotSlice), + make_freeze_transform(IdxZTilde_), + make_freeze_transform(IdxYTilde_), + make_freeze_transform(IdxXTilde_), + make_pass_through_transform(C_)), + make_tuple(sequence<0>{}, + sequence<1>{}, + sequence<3>{}, + sequence<5>{}, + sequence<2>{}, + sequence<4>{}, + sequence<6>{}, + sequence<7>{}), + make_tuple(sequence<0>{}, + sequence<1>{}, + 
sequence<2>{}, + sequence<3>{}, + sequence<>{}, + sequence<>{}, + sequence<>{}, + sequence<4>{})); + + const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( + wei_k_ydotslice_xdotslice_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), + make_pass_through_transform(C_)), + make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + // c: input + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_pad_transform(Di_, InLeftPadD_, InRightPadD_), + make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), + make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); + + const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_embed_transform(make_tuple(ZTilde_, DTilde_), + make_tuple(ConvDilationD_, ConvStrideD_)), + make_embed_transform(make_tuple(YTilde_, HTilde_), + make_tuple(ConvDilationH_, ConvStrideH_)), + make_embed_transform(make_tuple(XTilde_, WTilde_), + make_tuple(ConvDilationW_, ConvStrideW_)), + make_pass_through_transform(C_)), + make_tuple( + sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, + sequence<1, 2>{}, + sequence<3, 4>{}, + sequence<5, 6>{}, + sequence<7>{})); + + const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( + in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, + make_tuple(make_pass_through_transform(N_), + make_freeze_transform(IdxZTilde_), + make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), + make_freeze_transform(IdxYTilde_), + 
make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), + make_freeze_transform(IdxXTilde_), + make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), + make_pass_through_transform(C_)), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, @@ -962,138 +1169,26 @@ struct TransformConvBwdDataToGemm sequence<6>{}, sequence<7>{}), make_tuple(sequence<0>{}, + sequence<>{}, sequence<1>{}, + sequence<>{}, sequence<2>{}, + sequence<>{}, sequence<3>{}, - sequence<4>{}, - sequence<5>{}, - sequence<6>{}, - sequence<7>{})); + sequence<4>{})); - const auto out_gemmm_gemmkraw_grid_desc = transform_tensor_descriptor( - out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k_grid_desc, - make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), - make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice))), - make_tuple(sequence<1, 3, 5, 7>{}, sequence<0, 2, 4, 6>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( + in_n_htildeslice_wtildeslice_c_grid_desc, + make_tuple( + make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), + make_pass_through_transform(C_)), + make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - // B: weight tensor comes in K_N - const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor( - wei_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_embed_transform(make_tuple(ZDot_, ZTilde_), - make_tuple(ConvStrideD_ / GcdStrideDilationD_, I1)), - make_embed_transform(make_tuple(YDot_, YTilde_), - make_tuple(ConvStrideH_ / GcdStrideDilationH_, I1)), - make_embed_transform(make_tuple(XDot_, XTilde_), - make_tuple(ConvStrideW_ / GcdStrideDilationW_, I1)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 
4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto wei_k_ydotslice_xdotslice_c_grid_desc = - transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc, - make_tuple(make_pass_through_transform(K_), - make_slice_transform(ZDot_, I0, ZDotSlice), - make_slice_transform(YDot_, I0, YDotSlice), - make_slice_transform(XDot_, I0, XDotSlice), - make_freeze_transform(IdxZTilde_), - make_freeze_transform(IdxYTilde_), - make_freeze_transform(IdxXTilde_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<3>{}, - sequence<5>{}, - sequence<2>{}, - sequence<4>{}, - sequence<6>{}, - sequence<7>{}), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<>{}, - sequence<>{}, - sequence<>{}, - sequence<4>{})); - - const auto wei_gemmn_gemmkraw_grid_desc = transform_tensor_descriptor( - wei_k_ydotslice_xdotslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K_)), - make_pass_through_transform(C_)), - make_tuple(sequence<1, 2, 3, 0>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - // c: input - const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( - in_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_pad_transform(Di_, InLeftPadD_, InRightPadD_), - make_pad_transform(Hi_, InLeftPadH_, InRightPadH_), - make_pad_transform(Wi_, InLeftPadW_, InRightPadW_), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{})); - - const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor( - in_n_hip_wip_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_embed_transform(make_tuple(ZTilde_, DTilde_), - make_tuple(ConvDilationD_, ConvStrideD_)), - make_embed_transform(make_tuple(YTilde_, HTilde_), - 
make_tuple(ConvDilationH_, ConvStrideH_)), - make_embed_transform(make_tuple(XTilde_, WTilde_), - make_tuple(ConvDilationW_, ConvStrideW_)), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, - sequence<1, 2>{}, - sequence<3, 4>{}, - sequence<5, 6>{}, - sequence<7>{})); - - const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor( - in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc, - make_tuple(make_pass_through_transform(N_), - make_freeze_transform(IdxZTilde_), - make_slice_transform(DTilde_, IDTildeSliceBegin, DTildeSlice), - make_freeze_transform(IdxYTilde_), - make_slice_transform(HTilde_, IHTildeSliceBegin, HTildeSlice), - make_freeze_transform(IdxXTilde_), - make_slice_transform(WTilde_, IWTildeSliceBegin, WTildeSlice), - make_pass_through_transform(C_)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}, - sequence<6>{}, - sequence<7>{}), - make_tuple(sequence<0>{}, - sequence<>{}, - sequence<1>{}, - sequence<>{}, - sequence<2>{}, - sequence<>{}, - sequence<3>{}, - sequence<4>{})); - - const auto in_gemmmraw_gemmnraw_grid_desc = transform_tensor_descriptor( - in_n_htildeslice_wtildeslice_c_grid_desc, - make_tuple(make_merge_transform(make_tuple(N_, DTildeSlice, HTildeSlice, WTildeSlice)), - make_pass_through_transform(C_)), - make_tuple(sequence<0, 1, 2, 3>{}, sequence<4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tuple(out_gemmm_gemmkraw_grid_desc, - wei_gemmn_gemmkraw_grid_desc, - in_gemmmraw_gemmnraw_grid_desc); + return make_tuple(out_gemmm_gemmkraw_grid_desc, + wei_gemmn_gemmkraw_grid_desc, + in_gemmmraw_gemmnraw_grid_desc); + } } IndexType G_, N_, original_N_; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp index bf8536d268..4d9c09f597 
100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm.hpp @@ -15,6 +15,142 @@ namespace tensor_operation { namespace device { namespace instance { +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances); +#endif // CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__) +void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); +#endif +#if defined(CK_ENABLE_BF16) +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& instances); +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA + #if defined(CK_USE_XDL) #if defined(CK_ENABLE_FP16) void add_device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instances( @@ -409,6 +545,81 @@ struct DeviceOperationInstanceFactory> op_ptrs; +#if defined(CK_USE_WMMA) +#if defined(CK_ENABLE_FP16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + 
add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_FP16 +#if defined(CK_ENABLE_FP16) && defined(CK_ENABLE_FP8) && defined(__gfx12__) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances(op_ptrs); + } + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances(op_ptrs); + } + } +#endif +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances(op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA + #if defined(CK_USE_XDL) #if defined(CK_ENABLE_FP16) if constexpr(is_same_v && is_same_v && diff --git 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp new file mode 100644 index 0000000000..6d5da9208b --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp @@ -0,0 +1,205 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_splitk_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/loop_scheduler.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AccDataType = F32; +using DsDataType = Empty_Tuple; + +using DsLayout = Empty_Tuple; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3; +static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave; +static constexpr auto 
GemmMNKPadding = device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = device::GemmSpecialization::Default; + +// Instances for 2 byte datatypes in CRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_km_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, 
BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in CCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | 
Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // NOTE(review): the A-side ThreadCluster ArrangeOrder / SrcAccessOrder are S<2, 0, 1> in this RRR list, while the RCR list (same row-major A) uses S<1, 0, 2> - confirm the difference is intentional. + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, 
BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + 
DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported +// padding/scheduler/pipeline version combinations +template + typename LayoutInstances, + typename ADataType, // NOTE: type parameters as last so that they can be inferred from the + typename BDataType, // vector argument + typename EDataType> +void add_device_grouped_gemm_wmma_universal_instances( + std::vector>>& instances) +{ + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances(instances, + LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); +} + +// Helper function to add a list of layout instances for instances with matching A/B/E data types +// for all supported padding/scheduler/pipeline version combinations +template + typename LayoutInstances> +void add_device_grouped_gemm_wmma_universal_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, 
LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); + add_device_operation_instances( + instances, LayoutInstances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index efb6b2580a..7c64fa7850 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -57,7 +57,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Do not build XDL instances if gfx9 targets are not on the target list - if(NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_xdl") + if(((NOT INST_TARGETS MATCHES "gfx9" AND NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_XDL) AND source_name MATCHES "_xdl") message(DEBUG "removing xdl instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -67,7 +67,7 @@ function(add_instance_library INSTANCE_NAME) list(REMOVE_ITEM ARGN "${source}") endif() # Do not build WMMA instances if gfx11 targets are not on the target list - if(NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES "_wmma") + if(((NOT INST_TARGETS MATCHES "gfx11" AND NOT INST_TARGETS MATCHES "gfx12") OR FORCE_DISABLE_WMMA) AND source_name MATCHES "_wmma") message(DEBUG "removing wmma instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -88,7 +88,7 @@ function(add_instance_library INSTANCE_NAME) endif() endif() # Do not build WMMA gemm_universal_f8 for any targets except gfx12+ - if(NOT INST_TARGETS MATCHES "gfx12" AND source_name MATCHES 
"gemm_wmma_universal" AND source_name MATCHES "_f8_") + if((NOT INST_TARGETS MATCHES "gfx12" OR FORCE_DISABLE_WMMA) AND source_name MATCHES "gemm_wmma_universal" AND source_name MATCHES "_f8_") message(DEBUG "removing gemm_universal_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() @@ -274,7 +274,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only dl instances, but DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12")) + if(("${cmake_instance}" MATCHES "ONLY XDL_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12" OR FORCE_DISABLE_XDL)) message(DEBUG "Found only xdl instances, but gfx9 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -282,7 +282,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only MX instances, but gfx950 is not on the targets list. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) + if(("${cmake_instance}" MATCHES "ONLY WMMA_KERNELS") AND (((NOT INST_TARGETS MATCHES "gfx11") AND (NOT INST_TARGETS MATCHES "gfx12")) OR FORCE_DISABLE_WMMA)) message(DEBUG "Found only wmma instances, but gfx11 is not on the targets list. Skipping.") set(add_inst 0) endif() @@ -290,7 +290,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found only xdl and dl instances, but gfx9 is not on the targets listand DL_KERNELS is not set. Skipping.") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND (NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12")) + if(("${cmake_instance}" MATCHES "ONLY XDL_AND_WMMA_KERNELS") AND ((NOT INST_TARGETS MATCHES "gfx9|gfx11|gfx12") OR (FORCE_DISABLE_XDL AND FORCE_DISABLE_WMMA))) message(DEBUG "Found only xdl and wmma instances, but gfx11 and gfx9 are not on the targets list. 
Skipping.") set(add_inst 0) endif() @@ -333,20 +333,22 @@ FOREACH(subdir_path ${dir_list}) if((add_inst EQUAL 1)) get_filename_component(target_dir ${subdir_path} NAME) add_subdirectory(${target_dir}) - if("${cmake_instance}" MATCHES "gemm") - list(APPEND CK_DEVICE_GEMM_INSTANCES $) - elseif("${cmake_instance}" MATCHES "conv") - list(APPEND CK_DEVICE_CONV_INSTANCES $) - elseif("${cmake_instance}" MATCHES "mha") - list(APPEND CK_DEVICE_MHA_INSTANCES $) - elseif("${cmake_instance}" MATCHES "contr") - list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $) - elseif("${cmake_instance}" MATCHES "reduce") - list(APPEND CK_DEVICE_REDUCTION_INSTANCES $) - else() - list(APPEND CK_DEVICE_OTHER_INSTANCES $) - endif() - message(DEBUG "add_instance_directory ${subdir_path}") + if (TARGET device_${target_dir}_instance) + if("${cmake_instance}" MATCHES "gemm") + list(APPEND CK_DEVICE_GEMM_INSTANCES $) + elseif("${cmake_instance}" MATCHES "conv") + list(APPEND CK_DEVICE_CONV_INSTANCES $) + elseif("${cmake_instance}" MATCHES "mha") + list(APPEND CK_DEVICE_MHA_INSTANCES $) + elseif("${cmake_instance}" MATCHES "contr") + list(APPEND CK_DEVICE_CONTRACTION_INSTANCES $) + elseif("${cmake_instance}" MATCHES "reduce") + list(APPEND CK_DEVICE_REDUCTION_INSTANCES $) + else() + list(APPEND CK_DEVICE_OTHER_INSTANCES $) + endif() + message(DEBUG "add_instance_directory ${subdir_path}") + endif() else() message(DEBUG "skip_instance_directory ${subdir_path}") endif() diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt index b2e7e51a3c..ba54c6ffb3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS add_instance_library(device_grouped_gemm_instance device_grouped_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp @@ -36,4 +36,17 @@ add_instance_library(device_grouped_gemm_instance device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_bf16_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_multiple_d_splitk_xdl_two_stage_bf16_i8_bf16_mk_nk_mn_instance.cpp + + device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp + + device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp + + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp + device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp ) diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..6f8b31e663 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Col, + Row, + device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..2839890dcf --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Col, + Col, + device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..c41dbdfc7b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Row, + Row, + device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..55d1163900 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + BF16, + Row, + Col, + device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp new file mode 100644 index 0000000000..ea7eb0d615 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + F16, + Col, + Row, + device_grouped_gemm_wmma_universal_km_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp new file mode 100644 index 0000000000..816188c7ff --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_km_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_universal_instances< + F16, + Col, + Col, + device_grouped_gemm_wmma_universal_km_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..6680002d47 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + F16, + Row, + Row, + device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 0000000000..3e82899834 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_universal_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + F16, + Row, + Col, + device_grouped_gemm_wmma_universal_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..e93e9dff4a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ADataType = F16; +using BDataType = F8; +using EDataType = F16; + +template +using device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, 
EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + Row, + Row, + device_grouped_gemm_wmma_universal_f16_f8_f16_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 0000000000..e8f043d1f8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,57 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_splitk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using ADataType = F8; +using BDataType = F16; +using EDataType = F16; + +template +using device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, 
EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_CShuffleV3< Row, Row, DsLayout, ELayout, ADataType, BDataType, AccDataType, EDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + + add_device_grouped_gemm_wmma_universal_instances< + Row, + Row, + device_grouped_gemm_wmma_universal_f8_f16_f16_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_impl.hpp index 03a2ed3186..0ee0ee4c2e 100644 --- a/profiler/include/profiler/profile_grouped_gemm_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_impl.hpp @@ -42,10 +42,11 @@ bool profile_grouped_gemm_impl(int do_verification, const std::vector& StrideAs, const std::vector& StrideBs, const std::vector& StrideCs, - const std::vector& kbatches = {}, - int n_warmup = 1, - int n_iter = 10, - int instance_index = -1) + const std::vector& 
kbatches = {}, + int n_warmup = 1, + int n_iter = 10, + int instance_index = -1, + bool fail_if_no_supported_instance = false) { bool pass = true; // TODO: Fixme - we do not pass compute data type here but need it @@ -225,6 +226,7 @@ bool profile_grouped_gemm_impl(int do_verification, } } // profile device GEMM instances + int instances_supporting_all_batch_sizes = 0; for(auto& gemm_ptr : op_ptrs) { auto argument_ptr = @@ -268,6 +270,7 @@ bool profile_grouped_gemm_impl(int do_verification, kbatch_list = kbatches; } + bool all_batch_sizes_supported = true; for(std::size_t j = 0; j < kbatch_list.size(); j++) { auto kbatch_curr = kbatch_list[j]; @@ -367,10 +370,30 @@ bool profile_grouped_gemm_impl(int do_verification, } else { + all_batch_sizes_supported = false; std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" << std::endl; } } + + // If all batch sizes were supported by this instance, the instance can be marked as + // 'supported' for this problem + if(all_batch_sizes_supported) + { + ++instances_supporting_all_batch_sizes; + } + } + + // Warn if not a single instance was supported + if(instances_supporting_all_batch_sizes == 0) + { + std::cout << "Warning! No instance found that supported all of the batch sizes." 
+ << std::endl; + + if(fail_if_no_supported_instance) + { + return false; + } } if(time_kernel) @@ -384,6 +407,7 @@ bool profile_grouped_gemm_impl(int do_verification, std::cout << "grouped_gemm_instance (" << instance_index << "/" << num_kernel << "): Passed" << std::endl; } + return pass; } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 5e610cb76b..685f52cdac 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -53,6 +53,13 @@ struct GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; +struct GemmConfigPrefill : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; +}; + struct GemmConfigPreshuffleQuant : public GemmConfigBase { static constexpr bool PreshuffleQuant = true; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp index a75d871421..07aed62804 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -39,6 +39,12 @@ using AQuantTypes = ::testing::Types< std::tuple, std::tuple, + // PreshuffleQuant = false && TransposeC = false && Prefill + std::tuple, + std::tuple, + std::tuple, + std::tuple, + // PreshuffleQuant = false && TransposeC = true std::tuple, std::tuple, diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index f78ce29daa..c6b5180013 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -3,10 +3,15 @@ add_custom_target(test_grouped_gemm) -add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk_xdl.cpp) -if(result EQUAL 0) - target_link_libraries(test_grouped_gemm_splitk PRIVATE utility 
device_grouped_gemm_instance) - add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) +# NOTE: We test for XDL/WMMA support here instead of relying on the usual pattern matching in the parent CMakeLists. This is necessary +# as these tests are universal and don't have "xdl" or "wmma" in their name to signify their target arch. But they will fail to link +# the instance library if there are no instances present for the current arch. +if (CK_USE_XDL OR CK_USE_WMMA) + add_gtest_executable(test_grouped_gemm_splitk test_grouped_gemm_splitk.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_splitk PRIVATE utility device_grouped_gemm_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_splitk) + endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp index 1683e16323..56fb758f89 100644 --- a/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp +++ b/test/grouped_gemm/test_grouped_gemm_interface_xdl.cpp @@ -9,6 +9,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "test_grouped_gemm_util.hpp" +#include "test_grouped_gemm_interface_xdl.hpp" class TestGGemmSplitKInterface_MKNKMN : public ::testing::Test { diff --git a/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp b/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp new file mode 100644 index 0000000000..a04d13c1ea --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_interface_xdl.hpp @@ -0,0 +1,205 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/stream_config.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/utility/number.hpp" +#include "profiler/profile_grouped_gemm_impl.hpp" + +namespace ck { +namespace test { + +template +struct DeviceGroupedGemmSplitkInstanceWrapper +{ + using F16 = half_t; + using F32 = float; + using Row = ck::tensor_layout::gemm::RowMajor; + using Col = ck::tensor_layout::gemm::ColumnMajor; + using PassThrough = tensor_operation::element_wise::PassThrough; + + using EmptyTuple = ck::Tuple<>; + + template + using S = ck::Sequence; + + template + using I = ck::Number; + + using ABlockTransferThreadClusterArrageOrder = + std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; + using ABlockTransferSrcAccessOrder = + std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; + using ABlockTransferSrcVectorDim = std::conditional_t, I<3>, I<2>>; + using ABlockTransferDstScalarPerVector_K1 = + std::conditional_t, I<8>, I<2>>; + using ABlockLdsAddExtraM = std::conditional_t, I<1>, I<0>>; + + using BBlockTransferThreadClusterArrageOrder = + std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; + using BBlockTransferSrcAccessOrder = + std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; + using BBlockTransferSrcVectorDim = std::conditional_t, I<2>, I<3>>; + using BBlockTransferDstScalarPerVector_K1 = + std::conditional_t, I<2>, I<8>>; + using BBlockLdsAddExtraM = std::conditional_t, I<0>, I<1>>; + + using DeviceGroupedGemmSplitKInstance = + 
tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle< + ALayout, + BLayout, + EmptyTuple, + ELayout, + F16, + F16, + F32, + F16, + EmptyTuple, + F16, + PassThrough, + PassThrough, + PassThrough, + GemmSpec, + 1, + 128, + 128, + 128, + KPerBlock, + K1, + K1, + 16, + 16, + 8, + 4, + S<1, 4, 16, 1>, + ABlockTransferThreadClusterArrageOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim::value, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1::value, + ABlockLdsAddExtraM::value, + S<1, 4, 16, 1>, + BBlockTransferThreadClusterArrageOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim::value, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1::value, + BBlockLdsAddExtraM::value, + 1, + 1, + S<1, 16, 1, 8>, + CDEBlockTransferScalarPerVector_NPerBlock>; + + bool IsSupported(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) const + { + std::size_t n_groups = Ms.size(); + EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && + StrideBs.size() == n_groups && StrideCs.size() == n_groups) + << "The number of groups is not consistent!"; + + std::vector gemm_descs; + + for(std::size_t i = 0; i < n_groups; ++i) + { + gemm_descs.push_back(tensor_operation::device::GemmDesc{ + Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + } + + std::vector p_As(n_groups, nullptr); + std::vector p_Bs(n_groups, nullptr); + std::vector p_Cs(n_groups, nullptr); + auto p_Ds = std::vector>{}; + + auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; + auto argument = ggemm_instance.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); + if(kbatch > 1) + { + ggemm_instance.SetKBatchSize(&argument, kbatch); + } + + return ggemm_instance.IsSupportedArgument(argument); + } + + float 
Run(const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + int kbatch = 1) const + { + std::size_t n_groups = Ms.size(); + EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && + StrideBs.size() == n_groups && StrideCs.size() == n_groups) + << "The number of groups is not consistent!"; + + std::vector gemm_descs; + + for(std::size_t i = 0; i < n_groups; ++i) + { + gemm_descs.push_back(tensor_operation::device::GemmDesc{ + Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + } + + std::vector p_As(n_groups, nullptr); + std::vector p_Bs(n_groups, nullptr); + std::vector p_Cs(n_groups, nullptr); + auto p_Ds = std::vector>{}; + + auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; + auto argument = ggemm_instance.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); + if(kbatch > 1) + { + ggemm_instance.SetKBatchSize(&argument, kbatch); + } + if(kbatch > 1 && ck::is_gfx11_supported()) + { + EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument)); + return 0; + } + else + { + EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); + auto invoker = ggemm_instance.MakeInvoker(); + DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument)); + ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer()); + return invoker.Run(argument, StreamConfig{nullptr, false}); + } + } +}; + +} // namespace test +} // namespace ck diff --git a/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp b/test/grouped_gemm/test_grouped_gemm_splitk.cpp similarity index 62% rename from test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp rename to test/grouped_gemm/test_grouped_gemm_splitk.cpp index c237fd562e..968bea2109 100644 --- a/test/grouped_gemm/test_grouped_gemm_splitk_xdl.cpp +++ b/test/grouped_gemm/test_grouped_gemm_splitk.cpp @@ -24,21 
+24,48 @@ using Col = ck::tensor_layout::gemm::ColumnMajor; template class TestGroupedGemm : public ck::test::TestGroupedGemm { + public: + void SetUp() override + { + ck::test::TestGroupedGemm::SetUp(); + +#if defined(CK_USE_WMMA) + // The old XDL tests didn't fail if instances were not supported, so we want to keep that + // behaviour. When compiling WMMA instances and WMMA is supported, then we'll fail if a + // specific case is not supported. + this->fail_if_no_supported_instances_ = + ck::is_gfx11_supported() || ck::is_gfx12_supported(); +#endif + } }; // clang-format off using KernelTypes = ::testing::Types< + +#if defined(CK_USE_WMMA) + // WMMA only. No reason to not have it for XDL, but the instance was not defined and it was not in the original test. + std::tuple< Col, Col, Row, BF16, BF16, BF16>, +#endif + +#if defined(CK_USE_XDL) && defined(__gfx9__) + // XDL only at the moment, instances for WMMA not defined. NOTE(review): __gfx9__/__gfx12__ look like per-arch compile macros - confirm they are defined in the host pass that registers these types, otherwise the guarded cases are silently dropped. + std::tuple< Row, Row, Row, BF16, I8, BF16>, + std::tuple< Row, Col, Row, BF16, I8, BF16>, +#endif + +#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__)) + std::tuple< Row, Row, Row, F8, F16, F16>, + std::tuple< Row, Row, Row, F16, F8, F16>, +#endif + std::tuple< Row, Row, Row, F16, F16, F16>, std::tuple< Row, Col, Row, F16, F16, F16>, std::tuple< Col, Row, Row, F16, F16, F16>, std::tuple< Col, Col, Row, F16, F16, F16>, + std::tuple< Row, Row, Row, BF16, BF16, BF16>, std::tuple< Row, Col, Row, BF16, BF16, BF16>, - std::tuple< Col, Row, Row, BF16, BF16, BF16>, - std::tuple< Row, Row, Row, BF16, I8, BF16>, - std::tuple< Row, Col, Row, BF16, I8, BF16>, - std::tuple< Row, Row, Row, F16, F8, F16>, - std::tuple< Row, Row, Row, F8, F16, F16> + std::tuple< Col, Row, Row, BF16, BF16, BF16> >; // clang-format on diff --git a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc index 16c4ad5909..84558c89f9 100644 --- 
a/test/grouped_gemm/test_grouped_gemm_ut_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_ut_cases.inc @@ -65,6 +65,13 @@ TYPED_TEST(TestGroupedGemm, MNKPadded) TYPED_TEST(TestGroupedGemm, TestLargeKBatch) { + // gfx11 does not support split-K due to missing atomic add for fp16/bf16 + // Technically, we could still run the tests for fp32, but we currently don't have instances for + // it so we disable it entirely + if(ck::is_gfx11_supported()) + GTEST_SKIP() << "Split-K not supported for FP16/BF16 on GFX11 due to missing atomic add " + "instructions"; + const std::vector Ms{188, 210}; constexpr int N = 768; constexpr int K = 4096; diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 912066ee80..6ee6465cc4 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -11,16 +11,7 @@ #include #include "ck/ck.hpp" -#include "ck/stream_config.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/utility/data_type.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/utility/tuple.hpp" -#include "ck/utility/number.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" extern ck::index_t param_mask; @@ -41,7 +32,7 @@ std::string serialize_range(const Range& range) return std::string(str.begin(), str.end() - 2); } -template +template class TestGroupedGemm : public testing::Test { protected: @@ -62,9 +53,26 @@ class TestGroupedGemm : public testing::Test static constexpr bool bench_ = false; // measure kernel performance static constexpr int n_warmup_ = 0; static constexpr int n_iter_ = 1; + + bool fail_if_no_supported_instances_ = FailIfNoSupportedInstances; 
std::vector k_batches_; - void SetUp() override { k_batches_ = {1, 2, 3, 5, 8}; } + void SetUp() override + { + constexpr bool require_16bit_atomic_add = + std::is_same_v || std::is_same_v; + if(require_16bit_atomic_add && ck::is_gfx11_supported()) + { + // gfx11 does not support split-K due to missing atomic add for fp16/bf16 + // Technically, we could still use split-K for fp32, but we currently don't have + // instances for it so we disable it entirely + k_batches_ = {1}; + } + else + { + k_batches_ = {1, 2, 3, 5, 8}; + } + } private: template @@ -132,204 +140,31 @@ class TestGroupedGemm : public testing::Test const std::vector& StrideCs, const std::vector& kbatches) { - bool pass = ck::profiler::profile_grouped_gemm_impl(verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatches, - n_warmup_, - n_iter_, - instance_index); + bool pass = + ck::profiler::profile_grouped_gemm_impl(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); EXPECT_TRUE(pass); } }; -template -struct DeviceGroupedGemmSplitkInstanceWrapper -{ - using F16 = half_t; - using F32 = float; - using Row = ck::tensor_layout::gemm::RowMajor; - using Col = ck::tensor_layout::gemm::ColumnMajor; - using PassThrough = tensor_operation::element_wise::PassThrough; - - using EmptyTuple = ck::Tuple<>; - - template - using S = ck::Sequence; - - template - using I = ck::Number; - - using ABlockTransferThreadClusterArrageOrder = - std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; - using ABlockTransferSrcAccessOrder = - std::conditional_t, S<0, 2, 1, 3>, S<0, 1, 3, 2>>; - using ABlockTransferSrcVectorDim = std::conditional_t, I<3>, I<2>>; - using ABlockTransferDstScalarPerVector_K1 = - std::conditional_t, I<8>, I<2>>; - using ABlockLdsAddExtraM = std::conditional_t, I<1>, I<0>>; - - using BBlockTransferThreadClusterArrageOrder = - 
std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; - using BBlockTransferSrcAccessOrder = - std::conditional_t, S<0, 1, 3, 2>, S<0, 2, 1, 3>>; - using BBlockTransferSrcVectorDim = std::conditional_t, I<2>, I<3>>; - using BBlockTransferDstScalarPerVector_K1 = - std::conditional_t, I<2>, I<8>>; - using BBlockLdsAddExtraM = std::conditional_t, I<0>, I<1>>; - - using DeviceGroupedGemmSplitKInstance = - tensor_operation::device::DeviceGroupedGemmXdlSplitKCShuffle< - ALayout, - BLayout, - EmptyTuple, - ELayout, - F16, - F16, - F32, - F16, - EmptyTuple, - F16, - PassThrough, - PassThrough, - PassThrough, - GemmSpec, - 1, - 128, - 128, - 128, - KPerBlock, - K1, - K1, - 16, - 16, - 8, - 4, - S<1, 4, 16, 1>, - ABlockTransferThreadClusterArrageOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim::value, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_K1::value, - ABlockLdsAddExtraM::value, - S<1, 4, 16, 1>, - BBlockTransferThreadClusterArrageOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim::value, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_K1::value, - BBlockLdsAddExtraM::value, - 1, - 1, - S<1, 16, 1, 8>, - CDEBlockTransferScalarPerVector_NPerBlock>; - - bool IsSupported(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - int kbatch = 1) const - { - std::size_t n_groups = Ms.size(); - EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && - StrideBs.size() == n_groups && StrideCs.size() == n_groups) - << "The number of groups is not consistent!"; - - std::vector gemm_descs; - - for(std::size_t i = 0; i < n_groups; ++i) - { - gemm_descs.push_back(tensor_operation::device::GemmDesc{ - Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); - } - - std::vector p_As(n_groups, nullptr); - std::vector p_Bs(n_groups, nullptr); - 
std::vector p_Cs(n_groups, nullptr); - auto p_Ds = std::vector>{}; - - auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; - auto argument = ggemm_instance.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); - if(kbatch > 1) - { - ggemm_instance.SetKBatchSize(&argument, kbatch); - } - - return ggemm_instance.IsSupportedArgument(argument); - } - - float Run(const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - int kbatch = 1) const - { - std::size_t n_groups = Ms.size(); - EXPECT_TRUE(Ns.size() == n_groups && Ks.size() == n_groups && StrideAs.size() == n_groups && - StrideBs.size() == n_groups && StrideCs.size() == n_groups) - << "The number of groups is not consistent!"; - - std::vector gemm_descs; - - for(std::size_t i = 0; i < n_groups; ++i) - { - gemm_descs.push_back(tensor_operation::device::GemmDesc{ - Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); - } - - std::vector p_As(n_groups, nullptr); - std::vector p_Bs(n_groups, nullptr); - std::vector p_Cs(n_groups, nullptr); - auto p_Ds = std::vector>{}; - - auto ggemm_instance = DeviceGroupedGemmSplitKInstance{}; - auto argument = ggemm_instance.MakeArgument( - p_As, p_Bs, p_Ds, p_Cs, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{}); - if(kbatch > 1) - { - ggemm_instance.SetKBatchSize(&argument, kbatch); - } - if(kbatch > 1 && ck::is_gfx11_supported()) - { - EXPECT_FALSE(ggemm_instance.IsSupportedArgument(argument)); - return 0; - } - else - { - EXPECT_TRUE(ggemm_instance.IsSupportedArgument(argument)); - auto invoker = ggemm_instance.MakeInvoker(); - DeviceMem dev_gemm_kargs(ggemm_instance.GetDeviceKernelArgSize(&argument)); - ggemm_instance.SetDeviceKernelArgs(&argument, dev_gemm_kargs.GetDeviceBuffer()); - return invoker.Run(argument, StreamConfig{nullptr, false}); - } - } -}; - } // namespace test } // namespace ck