diff --git a/README.md b/README.md index 32688b6574..01d523c2ab 100644 --- a/README.md +++ b/README.md @@ -93,13 +93,44 @@ Docker images are available on [DockerHub](https://hub.docker.com/r/rocm/composa want to build the library for a list of different architectures, you should use the `GPU_ARCHS` build argument, for example `GPU_ARCHS=gfx908;gfx1030;gfx1100;gfx942`. -4. Build the entire CK library: + **Convenience script for development builds:** + + Alternatively, you can use the provided convenience script `script/cmake-ck-dev.sh`, which automatically + configures CK for development with sensible defaults. From the build directory, run: + + ```bash + ../script/cmake-ck-dev.sh + ``` + + This script: + * Cleans CMake cache files before configuring + * Sets `BUILD_DEV=ON` for development mode + * Defaults to GPU targets: `gfx908;gfx90a;gfx942` + * Enables verbose makefile output + * Sets additional compiler flags for better error messages + + By default, it considers the parent directory to be the project source directory. + + You can specify the source directory as the first argument. + You can specify custom GPU targets as the second argument (a semicolon-separated list; quote it so the shell does not split on `;`): + + ```bash + ../script/cmake-ck-dev.sh .. gfx1100 + ``` + + Or pass additional CMake arguments after the GPU targets: + + ```bash + ../script/cmake-ck-dev.sh .. gfx90a -DCMAKE_BUILD_TYPE=Release + ``` + +5. Build the entire CK library: ```bash make -j"$(nproc)" ``` -5. Install CK: +6. 
Install CK: ```bash make -j install diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index f92f6ef87a..3c26661c84 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -68,7 +68,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) else if(data_type == "pk_int4_t") { // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, ck_tile::half_t, diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index dbed40800e..6d833fbd7a 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -12,13 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 -#define CK_TILE_PIPELINE_COMPUTE_V6 5 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 6 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -69,7 +62,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool TiledMMAPermuteN = false; @@ -91,9 +84,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase static constexpr 
ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -111,8 +104,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -131,8 +124,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -150,8 +143,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -169,8 +162,8 @@ struct GemmConfigComputeV3_2 : 
public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -190,8 +183,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -213,8 +206,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -232,8 +225,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -252,7 +245,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); 
static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaveGroups = 2; }; @@ -272,7 +265,7 @@ struct GemmConfigComputeV6 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V6; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V6; static constexpr ck_tile::index_t NumWaveGroups = 1; }; @@ -291,13 +284,13 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template @@ -315,13 +308,13 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 2; - static constexpr auto Scheduler = 
ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; - static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr int N_Repeat = N_Tile / N_Warp_Tile / N_Warp; + static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; template @@ -465,11 +458,11 @@ struct DataTypeTraits static constexpr const char* name = "int8"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -478,7 +471,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -487,7 +480,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -496,7 +489,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; @@ -505,7 +498,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6; @@ -514,7 +507,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; diff --git 
a/example/ck_tile/03_gemm/universal_gemm.cpp b/example/ck_tile/03_gemm/universal_gemm.cpp index f9a7263a5f..a8a7288a3d 100644 --- a/example/ck_tile/03_gemm/universal_gemm.cpp +++ b/example/ck_tile/03_gemm/universal_gemm.cpp @@ -58,7 +58,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) else if(data_type == "fp16i4") { // TODO: Add support for bhalf_t ADataType - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, @@ -73,7 +73,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) } else if(data_type == "fp8i4") { - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, @@ -88,7 +88,7 @@ int run_gemm_example(ck_tile::ArgParser& arg_parser) } else if(data_type == "bf8i4") { - if constexpr(GemmConfig::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + if constexpr(GemmConfig::Pipeline == ck_tile::GemmPipeline::COMPUTE_V3) { return run_gemm_example_prec_type, Invoker, diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.hpp b/example/ck_tile/16_batched_gemm/batched_gemm.hpp index 33da0bf0a5..c0935a0e46 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.hpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.hpp @@ -11,10 +11,6 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - struct GemmConfigMemory { // Memory friendly for Interwave scheduler @@ -30,9 +26,9 @@ struct GemmConfigMemory static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static 
constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 @@ -50,9 +46,9 @@ struct GemmConfigV3 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 @@ -71,9 +67,9 @@ struct GemmConfigV4 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma @@ -91,16 +87,16 @@ struct GemmConfigV3_Wmma static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto 
Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -109,7 +105,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -118,7 +114,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 57d3f224d8..049957cbfd 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -11,11 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_PRESHUFFLE_V2 4 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -87,7 +82,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool Persistent = true; @@ -109,8 +104,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = 
CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 1; }; @@ -132,8 +127,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -155,8 +150,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -178,12 +173,12 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr bool kPadK = true; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = true; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = true; }; template @@ 
-201,12 +196,12 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); - static constexpr int kBlockPerCu = 2; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr bool kPadK = true; + static constexpr int kBlockPerCu = 2; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr bool kPadK = true; }; template @@ -226,8 +221,8 @@ struct GemmConfigComputeV4_Wmma : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; static constexpr int kBlockPerCu = 2; }; @@ -249,18 +244,18 @@ struct GemmConfigPreshuffleDecode_Wmma : public GemmConfigBase static constexpr bool kPadK = true; - static constexpr int kBlockPerCu = 1; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2; - static constexpr bool Preshuffle = true; - static constexpr bool DoubleSmemBuffer = true; + static constexpr int kBlockPerCu = 1; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::PRESHUFFLE_V2; + static constexpr 
bool Preshuffle = true; + static constexpr bool DoubleSmemBuffer = true; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -269,7 +264,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -278,7 +273,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -287,7 +282,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index 12d70eecb6..81c0b654e2 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -11,10 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 - template constexpr ck_tile::index_t get_k_warp_tile() { @@ -44,8 +40,8 @@ struct GemmConfigBase static constexpr int kBlockPerCu = 1; static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr bool Preshuffle = false; // currently preshuffle == true is not supported yet static constexpr 
bool Persistent = false; // currently persistent == true is not supported yet static constexpr bool DoubleSmemBuffer = @@ -67,10 +63,10 @@ struct GemmConfigMemory : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr bool Persistent = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr bool Persistent = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 : public GemmConfigBase @@ -88,10 +84,10 @@ struct GemmConfigV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool Persistent = true; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 : public GemmConfigBase { @@ -109,10 +105,10 @@ struct GemmConfigV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool Persistent = true; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr 
bool Persistent = true; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma : public GemmConfigBase @@ -130,16 +126,16 @@ struct GemmConfigV3_Wmma : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -148,7 +144,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -157,7 +153,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp index a7ae227627..8a621cd4be 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.hpp @@ -7,12 +7,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define 
CK_TILE_PIPELINE_COMPUTE_V4 3 - using ADataType = ck_tile::half_t; using BDataType = ck_tile::half_t; using D0DataType = ck_tile::half_t; @@ -36,9 +33,9 @@ struct GemmConfigMemory static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 8; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; struct GemmConfigV3 @@ -56,9 +53,9 @@ struct GemmConfigV3 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV4 @@ -77,9 +74,9 @@ struct GemmConfigV4 static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; struct GemmConfigV3_Wmma @@ -97,16 +94,16 @@ struct GemmConfigV3_Wmma static constexpr ck_tile::index_t N_Warp_Tile = 16; static 
constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -115,7 +112,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -124,7 +121,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/example/ck_tile/20_grouped_convolution/conv_configs.hpp b/example/ck_tile/20_grouped_convolution/conv_configs.hpp index 1be6080383..c688215280 100644 --- a/example/ck_tile/20_grouped_convolution/conv_configs.hpp +++ b/example/ck_tile/20_grouped_convolution/conv_configs.hpp @@ -12,11 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 - struct ConvConfigBase { static constexpr bool kPadM = true; @@ -37,7 +32,7 @@ struct ConvConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = 
ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; static constexpr bool Preshuffle = false; static constexpr bool TiledMMAPermuteN = false; @@ -61,9 +56,9 @@ struct ConvConfigMemoryInterwave : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -81,8 +76,8 @@ struct ConvConfigMemoryIntrawave : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -101,8 +96,8 @@ struct ConvConfigComputeV3 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -120,8 +115,8 @@ struct ConvConfigComputeV3_1 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr 
bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -139,8 +134,8 @@ struct ConvConfigComputeV3_2 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -160,8 +155,8 @@ struct ConvConfigComputeV3_WMMA : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -183,8 +178,8 @@ struct ConvConfigComputeV4 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -202,8 +197,8 @@ struct ConvConfigComputeV4_1 : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ 
-222,7 +217,7 @@ struct ConvConfigComputeV5 : public ConvConfigBase static constexpr ck_tile::index_t K_Warp_Tile = 16; static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; @@ -245,8 +240,8 @@ struct ConvConfigComputeV3_merged_groups : public ConvConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 32; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumGroupsToMerge = 2; }; @@ -294,11 +289,11 @@ struct DataTypeTraits static constexpr const char* name = "bf16"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -307,7 +302,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -316,7 +311,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -325,7 +320,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 89922fc07b..d9a65b9639 100644 --- 
a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -112,89 +112,87 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Regular Convolution: Simple, no split-image // ===================================================================== - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = + ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! 
Arguments not supported! Skipping conv!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; + return ave_time; + }; // ===================================================================== // Split-K lambda @@ -202,11 +200,11 @@ struct GroupedConvolutionForwardInvoker const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { if(args.k_batch == 1) { - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpSet{}); + Run.template 
operator()(has_hot_loop_, tail_number_, MemoryOpSet{}); } else { - Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); + Run.template operator()(has_hot_loop_, tail_number_, MemoryOpAtomicAdd{}); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 6a76057d73..2e98c0863b 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -53,7 +53,10 @@ struct GroupedConvolutionForwardInvoker OutLayout, VectorSizeA, VectorSizeB, - VectorSizeC>; + VectorSizeC, + 1, /*NumGroupsToMerge*/ + ck_tile::element_wise::PassThrough, + true /*EnableSplitImage*/>; using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits< GemmConfig::kPadM, @@ -238,68 +241,64 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Kernel launch lambda: Uses EnableSplitImage based on layout support // ===================================================================== - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + using UniversalGemmProblem = 
+ ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - // Use split-image kernel if layout supports it, otherwise use regular kernel - using Kernel = ck_tile::GroupedConvolutionForwardKernel; + // Use split-image kernel if layout supports it, otherwise use regular kernel + using Kernel = ck_tile::GroupedConvolutionForwardKernel; - // Create kargs - auto kargs = Kernel::MakeKernelArgs(args); + // Create kargs + auto kargs = Kernel::MakeKernelArgs(args); - // Populate split-image metadata ONLY if using split-image kernel - if constexpr(EnableSplitImage) - { + // Populate split-image metadata ONLY if using split-image kernel kargs.num_spatial_pieces = total_pieces; kargs.split_image.total_d = total_d; kargs.split_image.total_h = total_h; @@ -320,41 +319,35 @@ struct GroupedConvolutionForwardInvoker temp_pieces[i].h_size, temp_pieces[i].w_size}; } - } - // Calculate grid: use total_blocks for split-image, or normal GridSize for regular - const dim3 grids = [&]() { - if constexpr(EnableSplitImage) - return dim3(total_blocks, kargs.GemmBatch, kargs.n_splits); - else - return Kernel::GridSize(kargs); - }(); - const dim3 blocks = Kernel::BlockSize(); + // Calculate grid: use total_blocks for split-image, or normal GridSize for regular + const dim3 grids = dim3(total_blocks, kargs.GemmBatch, kargs.n_splits); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - return ave_time; - }; + return ave_time; + }; // ===================================================================== // Step 4: Dispatch kernel (split-image or regular based on decision) diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index 91fa444f0d..b0e2c02973 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -11,7 +11,9 @@ 
#include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/grouped_convolution.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "conv_configs.hpp" + using MemoryOpSet = std::integral_constant; using MemoryOpAtomicAdd = std::integral_constant +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -128,7 +121,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -137,7 +130,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; diff --git a/experimental/builder/include/ck_tile/builder/CMakeLists.txt b/experimental/builder/include/ck_tile/builder/CMakeLists.txt index f20b5d54ec..45723c3680 100644 --- a/experimental/builder/include/ck_tile/builder/CMakeLists.txt +++ b/experimental/builder/include/ck_tile/builder/CMakeLists.txt @@ -1 +1 @@ -# Empty placeholder until we add library code. +#Empty placeholder until we add library code. 
diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index ebd168a7d0..ea4e6de6fd 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -425,6 +425,11 @@ struct DeviceGemm_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmV2BPreshuffle 0) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp index 78546c4f99..6ce2f63e3a 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp @@ -40,14 +40,22 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset, p_shared, - karg); + karg, + k_id, + Kt); } #else ignore = karg; @@ -74,15 +82,23 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg); + const index_t 
num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_c_grid + splitk_batch_offset.c_reduce_offset, p_shared_0, p_shared_1, - karg); + karg, + k_id, + Kt); } #else ignore = karg; @@ -658,25 +674,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA; } - if constexpr(is_same_v) - { - b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB; - } - else if constexpr(is_same_v) - { - if constexpr(!PermuteB) - { - // b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize; - - b_k_split_offset = blockIdx.z * karg.KRead * NLane / BPackedSize; - } - else - { - const int k0_offset = karg.KRead * karg.N; - b_k_split_offset = blockIdx.z * k0_offset / BPackedSize; - } - } - if(blockIdx.z < static_cast(karg.KBatch - 1)) { karg.K = karg.KRead; @@ -697,7 +694,6 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle } index_t a_k_split_offset; - index_t b_k_split_offset; index_t c_reduce_offset; }; @@ -900,6 +896,11 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); + if constexpr(NXdlPerWave % CShuffleNXdlPerWavePerShuffle != 0) + { + return false; + } + if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding || @@ -1134,7 +1135,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1, const BGridDesc_BPreshuffled& b_grid_desc_bpreshuffled, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& - c_grid_desc_mblock_mperblock_nblock_nperblock) + c_grid_desc_mblock_mperblock_nblock_nperblock, + const index_t k_id) { const 
auto a_grid_buf = make_dynamic_buffer( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1226,7 +1228,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle true>(b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1465,10 +1467,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle const BDataType* p_b_grid, CDataType* p_c_grid, void* p_shared, - const Problem& problem) + const Problem& problem, + const index_t k_id, + const index_t Kt) { index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BK0Shuffled = CalculateBK0Shuffled(Kt); const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = @@ -1491,7 +1495,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bpreshuffled, - c_grid_desc_mblock_mperblock_nblock_nperblock); + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id); } template ( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -1606,7 +1612,7 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle true>(b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPack * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1849,10 +1855,12 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle CDataType* p_c_grid, void* p_shared_0, void* p_shared_1, - const Problem& problem) + const Problem& problem, + const index_t k_id, + const index_t Kt) { index_t BN0Shuffled = CalculateBN0Shuffled(problem.N); - index_t BK0Shuffled = CalculateBK0Shuffled(problem.K); + index_t BK0Shuffled = CalculateBK0Shuffled(Kt); const auto a_grid_desc_ak0_m_ak1 = 
MakeAGridDescriptor_AK0_M_AK1( problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); const auto b_grid_desc_bpreshuffled = @@ -1877,7 +1885,8 @@ struct GridwiseGemm_xdl_cshuffle_v3_b_preshuffle problem, a_grid_desc_ak0_m_ak1, b_grid_desc_bpreshuffled, - c_grid_desc_mblock_mperblock_nblock_nperblock); + c_grid_desc_mblock_mperblock_nblock_nperblock, + k_id); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp index 2e95ec0d52..f2f1530599 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d_b_preshuffle.hpp @@ -43,18 +43,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + const index_t num_k_per_block = GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_ds_grid, karg.p_c_grid, p_shared, karg, karg.a_element_op, karg.b_element_op, - karg.c_element_op); + karg.c_element_op, + k_id, + Kt); } #else ignore = karg; @@ -79,11 +87,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + // Full K needed for matrix B + const index_t Kt = karg.K; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + const index_t num_k_per_block = 
GridwiseGemm::CalculateBK0Shuffled(karg.K); + const index_t k_id = blockIdx.z * num_k_per_block; + GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, - karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_grid, karg.p_ds_grid, karg.p_c_grid, p_shared, @@ -91,7 +105,9 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) karg, karg.a_element_op, karg.b_element_op, - karg.c_element_op); + karg.c_element_op, + k_id, + Kt); } #else ignore = karg; @@ -691,16 +707,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_k_split_offset = k_id * karg.KRead * karg.StrideA; } - if constexpr(is_same_v) - { - b_k_split_offset = k_id * karg.KRead * karg.StrideB; - } - else if constexpr(is_same_v) - { - // KPack * NLane * KLane * K0 * N0 - b_k_split_offset = k_id * karg.KRead * NLane; - } - if(k_id < karg.KBatch - 1) { karg.K = karg.KRead; @@ -712,7 +718,6 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle } index_t a_k_split_offset; - index_t b_k_split_offset; }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -1163,7 +1168,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const Problem& problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const index_t k_id, + const index_t Kt) { const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; Run( @@ -1176,7 +1183,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_element_op, b_element_op, c_element_op, - block_2_ctile_map); + block_2_ctile_map, + k_id, + Kt); } template (b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment @@ -1597,7 +1608,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle const Problem& 
problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) + CElementwiseOperation c_element_op, + const index_t k_id, + const index_t Kt) { const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; Run_2Lds( @@ -1611,7 +1624,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3_b_preshuffle a_element_op, b_element_op, c_element_op, - block_2_ctile_map); + block_2_ctile_map, + k_id, + Kt); } template (b_grid_desc_bpreshuffled, make_multi_index(n_block_data_idx_on_grid, get_warp_local_1d_id() % NWave, - 0, + k_id, KPackPerGroup * (get_thread_local_1d_id() % WarpSize))); // LDS allocation for A and B: be careful of alignment diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 2918cd33bc..f6189c7495 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -10,6 +10,26 @@ #include #include +#if !defined(CK_TILE_HAS_ROW_NEWBCAST) +// row_newbcast (DPP modifier 0x157) support by architecture: +// - Not supported: gfx908 (MI100) and older +// - Supported: gfx90a (MI200), gfx94x (MI300), and all RDNA architectures + +#if defined(__HIP_DEVICE_COMPILE__) && defined(__HIP_PLATFORM_AMD__) +#if defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__) +// Explicitly disable for known unsupported architectures +#define CK_TILE_HAS_ROW_NEWBCAST 0 +#else +// Assume support for gfx90a and newer (including all gfx94x and RDNA) +// This is safer as new architectures typically maintain backward compatibility +#define CK_TILE_HAS_ROW_NEWBCAST 1 +#endif +#else +// Conservative default for non-AMD or host compilation +#define CK_TILE_HAS_ROW_NEWBCAST 0 +#endif +#endif + namespace ck_tile { #define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \ @@ -380,18 +400,7 @@ struct MoeSortingKernel row_mask, bank_mask, bound_ctrl))); // row_shr:8 -#if 
0 - constexpr int bank_mask_0_7 = 0b1100; - auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; - thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, - __builtin_amdgcn_update_dpp(0, /* old value */ - __builtin_bit_cast(int, thread_data), - 0x157, - row_mask, - bank_mask_0_7, - bound_ctrl))// row_newbcast:7 - ); -#else +#if CK_TILE_HAS_ROW_NEWBCAST data_t xxx =__builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), 0x157, @@ -401,6 +410,17 @@ struct MoeSortingKernel data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx; thread_data = thread_data - yyy; +#else + // portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute + // For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group + int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group + int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address + int bcast7 = __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data)); + + // Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit) + if ((__lane_id() / 8) % 2 != 0) { // Note: != 0, not == 0 + thread_data = thread_data - __builtin_bit_cast(data_t, bcast7); + } #endif } @@ -1267,18 +1287,7 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) row_mask, bank_mask, bound_ctrl))); // row_shr:8 -#if 0 - constexpr int bank_mask_0_7 = 0b1100; - auto reduce_op_r = [&](auto x_, auto y_) { return x_ - y_; }; - thread_data = reduce_op_r(thread_data, __builtin_bit_cast(data_t, - __builtin_amdgcn_update_dpp(0, /* old value */ - __builtin_bit_cast(int, thread_data), - 0x157, - row_mask, - bank_mask_0_7, - bound_ctrl))// row_newbcast:7 - ); -#else +#if CK_TILE_HAS_ROW_NEWBCAST data_t xxx = __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data), @@ -1289,6 +1298,19 @@ CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t& thread_data) data_t yyy = 
(__lane_id() / 8) % 2 == 0 ? 0 : xxx; thread_data = thread_data - yyy; +#else + // portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute + // For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group + int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group + int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address + int bcast7 = + __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data)); + + // Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit) + if((__lane_id() / 8) % 2 != 0) + { // Note: != 0, not == 0 + thread_data = thread_data - __builtin_bit_cast(data_t, bcast7); + } #endif } if constexpr(wave_size > 8) diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 33be18948b..ec2d2488c8 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -55,6 +55,7 @@ #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp index 1d2a3e180b..91da3cd27b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async.hpp @@ -25,6 +25,10 @@ struct BaseGemmPipelineAgBgCrCompAsync CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop) { + if(num_loop == 1) + { + return TailNumber::One; + } if(num_loop % 
PrefetchStages == 1) { return TailNumber::Three; @@ -65,6 +69,11 @@ struct BaseGemmPipelineAgBgCrCompAsync return run_func(bool_constant{}, integral_constant{}); } + else + { + return (run_func(bool_constant{}, + integral_constant{})); + } } // If execution reaches here, it's an invalid tail_number because it wasn't handled above. #if defined(__HIP_DEVICE_COMPILE__) @@ -485,7 +494,7 @@ struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync{}, integral_constant{}); } + else + { + return (run_func(bool_constant{}, + integral_constant{})); + } } // If execution reaches here, it's an invalid tail_number because it wasn't handled above. #if defined(__HIP_DEVICE_COMPILE__) @@ -621,7 +630,7 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 __builtin_amdgcn_sched_barrier(0); } } - else + else if(TailNum == TailNumber::Two) { // 2 { @@ -641,6 +650,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4 __builtin_amdgcn_sched_barrier(0); } } + else if(TailNum == TailNumber::One) + { + block_sync_lds(); + block_gemm(c_block_tile, a_block_tile0, b_block_tile0); + __builtin_amdgcn_sched_barrier(0); + } return c_block_tile; } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp new file mode 100644 index 0000000000..9b948626f6 --- /dev/null +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp @@ -0,0 +1,21 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +enum struct GemmPipeline +{ + COMPUTE_ASYNC, + COMPUTE_V3, + COMPUTE_V4, + COMPUTE_V5, + COMPUTE_V6, + MEMORY, + BASIC_V1, + BASIC_V2, + PRESHUFFLE_V2 +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 6dd9eca9ff..8676530d87 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -434,14 +434,13 @@ struct GroupedConvFwdKernelArgs /// multiplication implementation. It is responsible for storing /// results calculated by @ref GemmPipeline_ "GemmPipeline" to /// the output C tensor in global memory. -template struct GroupedConvolutionForwardKernel { - static constexpr bool EnableSplitImage = EnableSplitImage_; + static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage; static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial; static constexpr ConvolutionSpecialization ConvSpecialization = GroupedConvTraitsType_::ConvSpecialization; diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 8695fecac6..703205fd6e 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -63,7 +63,8 @@ template + typename CDElementwise_ = PassThrough, + bool EnableSplitImage_ = false> struct GroupedConvTraits { private: @@ -74,6 +75,7 @@ struct GroupedConvTraits } public: + static constexpr bool EnableSplitImage = EnableSplitImage_; static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_; static constexpr index_t NDimSpatial = 
NDimSpatial_; static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_; diff --git a/profiler/include/profiler/common.hpp b/profiler/include/profiler/common.hpp new file mode 100644 index 0000000000..2f72e67c6b --- /dev/null +++ b/profiler/include/profiler/common.hpp @@ -0,0 +1,103 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include "ck/utility/data_type.hpp" + +namespace ck { +namespace profiler { + +template +inline __host__ __device__ constexpr double get_rtol() +{ + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 1.5e-1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +template +inline __host__ __device__ constexpr double get_atol() +{ + if constexpr(std::is_same_v && std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 1e-6; + } + else if constexpr(std::is_same_v) + { + return 1e-3; + } + else if constexpr(std::is_same_v) + { + return 5e-2; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 1e-1; + } + else if constexpr(std::is_same_v) + { + return 16.1; // 240 and 224 are acceptable + } + else if constexpr(std::is_same_v) + { + return 8192.1; // 57344 and 49152 are acceptable + } + else + { + return 1e-3; + } +} + +} // namespace profiler +} // namespace ck 
diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 0921b48842..da0dc60760 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -69,19 +69,19 @@ template -bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, - int init_method, - bool do_log, - bool time_kernel, - int M, - int N, - int K, - int StrideA, - int StrideB, - int StrideE, - int n_warmup, - int n_iter, - uint64_t rotating = 0) +bool profile_gemm_blockscale_weightpreshuffle_impl(int do_verification, + int init_method, + bool do_log, + bool time_kernel, + int M, + int N, + int K, + int StrideA, + int StrideB, + int StrideE, + int n_warmup, + int n_iter, + uint64_t rotating = 0) { bool pass = true; @@ -126,6 +126,26 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification, Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a0_m_k, ALayout{}, StrideA); + StrideB = get_stride(b0_k_n, BLayout{}, StrideB); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + int total_gemm_needed = a0_m_k.GetElementSpaceSizeInBytes() + b0_k_n.GetElementSpaceSizeInBytes() + a1_m_k.GetElementSpaceSizeInBytes() + b1_k_n.GetElementSpaceSizeInBytes(); diff --git a/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp b/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp index c76387e2b0..21613e49c6 100644 --- 
a/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp +++ b/profiler/include/profiler/profile_gemm_multiply_multiply_wp_impl.hpp @@ -20,6 +20,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { @@ -112,6 +113,28 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a_m_k, ALayout{}, StrideA); + StrideB = get_stride(b_k_n, BLayout{}, StrideB); + StrideD0 = get_stride(d0_m_n, D0Layout{}, StrideD0); + StrideD1 = get_stride(d1_m_n, D1Layout{}, StrideD1); + StrideE = get_stride(e_m_n_host_result, ELayout{}, StrideE); + int total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + d0_m_n.GetElementSpaceSizeInBytes() + d1_m_n.GetElementSpaceSizeInBytes(); @@ -133,7 +156,7 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, case 1: a_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); b_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); - d0_m_n.GenerateTensorValue(GeneratorTensor_2{-5, 5}); + d0_m_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); d1_m_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); break; default: @@ -282,8 +305,8 @@ bool profile_gemm_multiply_multiply_weight_preshuffle_impl(int do_verification, is_same_v)) { std::string msg = "Error: 
Incorrect results!"; - double rtol = 1e-3; - double atol = 5e-2; + double rtol = get_rtol(); + double atol = get_atol(); pass = pass & ck::utils::check_err( e_m_n_device_result, e_m_n_host_result, msg, rtol, atol); } diff --git a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp index e537cf2770..5ec056efd1 100644 --- a/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp +++ b/profiler/include/profiler/profile_gemm_universal_preshuffle_impl.hpp @@ -20,6 +20,7 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { @@ -99,6 +100,26 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, Tensor c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); Tensor c_m_n_device_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{})); + // Update strides based on tensor properties if they are <= 0 + auto get_stride = [](auto& tensor, auto layout, ck::index_t current_stride) -> ck::index_t { + if(current_stride <= 0) + { + if constexpr(std::is_same_v) + { + return tensor.GetStrides()[0]; + } + else + { + return tensor.GetStrides()[1]; + } + } + return current_stride; + }; + + StrideA = get_stride(a_m_k, ALayout{}, StrideA); + StrideB = get_stride(b_k_n, BLayout{}, StrideB); + StrideC = get_stride(c_m_n_host_result, CLayout{}, StrideC); + std::size_t total_gemm_needed = a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes(); int rotating_count = std::max( @@ -317,8 +338,8 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification, is_same_v) { std::string msg = "Error: Incorrect results!"; - double rtol = 1e-1; - double atol = 1e-1; + double rtol = get_rtol(); + double atol = get_atol(); pass = pass & ck::utils::check_err( 
c_m_n_device_result, c_m_n_host_result, msg, rtol, atol); } diff --git a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp index b553e07735..ae12070014 100644 --- a/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp +++ b/profiler/include/profiler/profile_grouped_conv_fwd_outelementop_impl.hpp @@ -5,92 +5,11 @@ #include "ck/library/reference_tensor_operation/cpu/reference_conv_fwd.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/host_tensor_generator.hpp" +#include "profiler/common.hpp" namespace ck { namespace profiler { -template -inline constexpr double get_rtol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 1.5e-1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - -template -inline constexpr double get_atol() -{ - if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 1e-6; - } - else if constexpr(std::is_same_v) - { - return 1e-3; - } - else if constexpr(std::is_same_v) - { - return 5e-2; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 1e-1; - } - else if constexpr(std::is_same_v) - { - return 16.1; // 240 and 224 are acceptable - } - else if constexpr(std::is_same_v) - { - return 8192.1; // 57344 and 49152 are acceptable - } - else - { - return 1e-3; - } -} - template ? N : K; const int DefaultStrideE = ck::is_same_v ? 
N : M; - bool pass = ck::profiler::profile_gemm_blockscale_weighpreshuffle_impl( + bool pass = ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl( do_verification, init_method, do_log, diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 810ae8d231..d47e55db64 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -245,10 +245,13 @@ add_subdirectory(conv_util) add_subdirectory(reference_conv_fwd) add_subdirectory(gemm) add_subdirectory(gemm_add) +add_subdirectory(gemm_blockscale_wp) add_subdirectory(gemm_layernorm) add_subdirectory(gemm_multi_abd) +add_subdirectory(gemm_multiply_multiply_wp) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) +add_subdirectory(gemm_universal_preshuffle) add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp index 0820be5b30..1f9033cab9 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -10,11 +10,6 @@ #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#define CK_TILE_PIPELINE_COMPUTE_V3 1 -#define CK_TILE_PIPELINE_MEMORY 2 -#define CK_TILE_PIPELINE_COMPUTE_V4 3 -#define CK_TILE_PIPELINE_COMPUTE_V5 4 - class ArgumentsNotSupportedException : public std::logic_error { public: @@ -56,7 +51,7 @@ struct GemmConfigBase static constexpr ck_tile::index_t TileParitionerGroupNum = 8; static constexpr ck_tile::index_t TileParitionerM01 = 4; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr ck_tile::index_t NumWaveGroups = 1; }; @@ -76,9 +71,9 @@ struct GemmConfigMemoryInterwave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 
32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; template @@ -96,8 +91,8 @@ struct GemmConfigMemoryIntrawave : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::MEMORY; }; template @@ -116,8 +111,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -135,8 +130,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; }; template @@ -154,8 +149,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static 
constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -177,8 +172,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -196,8 +191,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr bool DoubleSmemBuffer = true; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; }; template @@ -216,7 +211,7 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; static constexpr ck_tile::index_t NumWaNumWaveGroups = 2; }; @@ -235,8 +230,8 @@ struct GemmConfigComputeV3_WMMA : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = 16; - static constexpr bool DoubleSmemBuffer = false; - static 
constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; static constexpr int kBlockPerCu = 2; }; @@ -401,11 +396,11 @@ struct DataTypeTraits static constexpr const char* name = "int8"; }; -template +template struct PipelineTypeTraits; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; @@ -414,7 +409,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; @@ -423,7 +418,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; @@ -432,7 +427,7 @@ struct PipelineTypeTraits }; template <> -struct PipelineTypeTraits +struct PipelineTypeTraits { template using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; diff --git a/test/ck_tile/gemm_tile_engine/CMakeLists.txt b/test/ck_tile/gemm_tile_engine/CMakeLists.txt index 0174028c99..8ad0f2af75 100644 --- a/test/ck_tile/gemm_tile_engine/CMakeLists.txt +++ b/test/ck_tile/gemm_tile_engine/CMakeLists.txt @@ -32,43 +32,35 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti set(target_name "test_gemm_tile_engine_${datatype}_${layout}_${config_name}_${trait}_${tile_config}") set(working_path "${CMAKE_CURRENT_BINARY_DIR}/${datatype}/${layout}/${config_name}") - # Generated header path for this specific kernel configuration + # Generated header path (already created during cmake configuration) set(test_header "${working_path}/gemm_single_${datatype}_${layout}_${trait}_${tile_config}.hpp") + set(test_params_header "${working_path}/test_params.hpp") - # Generate kernel header using tile_engine's Python script - add_custom_command( - OUTPUT ${test_header} - COMMAND 
${Python3_EXECUTABLE} ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py - --working_path ${working_path} - --gpu_target "${GEMM_TEST_GPU_TARGETS}" - --datatype ${datatype} - --layout ${layout} - --config_json ${config_json} - --gen_single - --kernel_name "test_gemm_${datatype}_${layout}_${trait}_${tile_config}" - --tile_config "${tile_config}" - --trait_combo "${trait}" - DEPENDS ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py ${config_json} - COMMENT "Generating test header ${test_header}" - VERBATIM - ) + # Verify header exists (should have been generated during cmake configuration) + if(NOT EXISTS ${test_header}) + message(WARNING "Generated header not found: ${test_header}") + return() + endif() + + # Verify test parameters header exists + if(NOT EXISTS ${test_params_header}) + message(WARNING "Test parameters header not found: ${test_params_header}") + return() + endif() + # Create GTest executable for this kernel configuration add_gtest_executable(${target_name} ${CMAKE_CURRENT_SOURCE_DIR}/test_gemm_simple.cpp ) - # Ensure header is generated before compilation - set(header_target "${target_name}_header") - add_custom_target(${header_target} DEPENDS ${test_header}) - add_dependencies(${target_name} ${header_target}) - # Configure GPU architectures for HIP compilation set_property(TARGET ${target_name} PROPERTY HIP_ARCHITECTURES ${GEMM_TEST_GPU_TARGETS}) - # Define preprocessor macros for generated header location + # Define preprocessor macros for generated header location and test parameters target_compile_definitions(${target_name} PRIVATE GEMM_SINGLE_INSTANCE_HPP="${test_header}" + GEMM_TEST_PARAMS_HPP="${test_params_header}" ) # Include directories for headers and dependencies @@ -87,6 +79,11 @@ function(create_individual_gemm_test_target datatype layout config_name trait ti -include ${test_header} # Auto-include generated header ) + # Add FP8 format definitions for proper data type interpretation + if(CK_USE_OCP_FP8) + 
target_compile_options(${target_name} PRIVATE -DCK_TILE_USE_OCP_FP8) + endif() + message(STATUS " Created test target: ${target_name}") endfunction() @@ -107,7 +104,6 @@ function(build_gemm_test_targets datatype layout config_name) # Locate and validate configuration file set(config_filename "${config_name}.json") set(json_blob "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config_filename}") - message(STATUS " Using test config: ${config_filename}") if(NOT EXISTS ${json_blob}) message(WARNING "Test config file not found: ${json_blob}") @@ -118,7 +114,6 @@ function(build_gemm_test_targets datatype layout config_name) file(MAKE_DIRECTORY ${working_path}) # STEP 1: Discovery phase - list all valid kernel configurations - message(STATUS " Listing kernel configurations for ${datatype}_${layout}...") execute_process( COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py --working_path ${working_path} @@ -134,32 +129,90 @@ function(build_gemm_test_targets datatype layout config_name) ) if(NOT ret EQUAL 0) - message(WARNING "Failed to list kernels for ${datatype}_${layout}: ${list_error}") + message(WARNING "Failed to list kernels for ${datatype}_${layout}_${config_name}: ${list_error}") return() endif() - # Validate kernel discovery results - if(EXISTS ${working_path}/gemm_kernel_count.txt) - file(READ ${working_path}/gemm_kernel_count.txt kernel_count) - string(STRIP "${kernel_count}" kernel_count) - message(STATUS " Found ${kernel_count} test configurations for ${datatype}_${layout}") - else() - message(WARNING "Kernel count file not found for ${datatype}_${layout}") + # Verify kernel list file was generated + if(NOT EXISTS ${working_path}/gemm_kernel_list.txt) + message(STATUS "No kernels found for ${datatype}_${layout}_${config_name} (validation filtered out all combinations)") return() endif() - # STEP 2: Generation phase - create test targets for each discovered kernel - if(EXISTS ${working_path}/gemm_kernel_list.txt) - file(STRINGS 
${working_path}/gemm_kernel_list.txt kernel_lines) - set(test_count 0) - foreach(line IN LISTS kernel_lines) - # Parse kernel specification format: kernel_name|tile_config|trait_combo - string(REPLACE "|" ";" parts "${line}") - list(LENGTH parts parts_len) - if(parts_len EQUAL 3) - list(GET parts 0 kernel_name) - list(GET parts 1 tile_config) - list(GET parts 2 trait_combo) + message(STATUS "Building tests for ${datatype}_${layout}_${config_name}") + + # STEP 2a: Extract test parameters from config + set(test_params_file "${working_path}/test_params.hpp") + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${CMAKE_CURRENT_SOURCE_DIR}/extract_test_params.py + --config_file ${json_blob} + --output_file ${test_params_file} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + RESULT_VARIABLE extract_ret + OUTPUT_VARIABLE extract_output + ERROR_VARIABLE extract_error + ) + + if(NOT extract_ret EQUAL 0) + message(WARNING "Failed to extract test parameters for ${datatype}_${layout}: ${extract_error}") + return() + endif() + + # STEP 2b: Header generation phase - generate headers using --gen_single + message(STATUS " Generating headers using --gen_single...") + + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + set(gen_count 0) + + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) + + # Generate header using --gen_single + execute_process( + COMMAND ${Python3_EXECUTABLE} -u ${TILE_ENGINE_GEMM_DIR}/gemm_instance_builder.py + --working_path ${working_path} + --gpu_target "${GEMM_TEST_GPU_TARGETS}" + --datatype ${datatype} + --layout ${layout} + --config_json ${json_blob} + --gen_single + --kernel_name "${kernel_name}" + --tile_config "${tile_config}" + --trait_combo "${trait_combo}" + WORKING_DIRECTORY 
${TILE_ENGINE_GEMM_DIR} + RESULT_VARIABLE gen_ret + OUTPUT_VARIABLE gen_output + ERROR_VARIABLE gen_error + ) + + if(NOT gen_ret EQUAL 0) + message(WARNING "Failed to generate header for ${kernel_name}: ${gen_error}") + else() + math(EXPR gen_count "${gen_count} + 1") + endif() + endif() + endforeach() + + message(STATUS " Generated ${gen_count} headers for ${datatype}_${layout}") + + # STEP 3: Target creation phase - create test targets + message(STATUS " Creating test targets...") + file(STRINGS ${working_path}/gemm_kernel_list.txt kernel_lines) + set(test_count 0) + foreach(line IN LISTS kernel_lines) + # Parse kernel specification format: kernel_name|tile_config|trait_combo + string(REPLACE "|" ";" parts "${line}") + list(LENGTH parts parts_len) + if(parts_len EQUAL 3) + list(GET parts 0 kernel_name) + list(GET parts 1 tile_config) + list(GET parts 2 trait_combo) # Generate test target for this kernel configuration create_individual_gemm_test_target("${datatype}" "${layout}" "${config_name}" "${trait_combo}" "${tile_config}" "${json_blob}") @@ -167,12 +220,7 @@ function(build_gemm_test_targets datatype layout config_name) endif() endforeach() message(STATUS " Created ${test_count} test targets for ${datatype}_${layout}") - else() - message(WARNING "Kernel list file not found for ${datatype}_${layout}") - endif() -endfunction() - -# ============================================================================ +endfunction()# ============================================================================ # MAIN EXECUTION - Test Target Generation # ============================================================================ @@ -198,42 +246,100 @@ endif() message(STATUS "Building GEMM tile engine tests for GPU targets: ${GEMM_TEST_GPU_TARGETS}") -# ============================================================================ -# Test Configuration Matrix -# ============================================================================ + # Enable parallel compilation 
optimizations + # Set up job pools for better parallel compilation control + set_property(GLOBAL PROPERTY JOB_POOLS + compile_heavy=4 # Limit heavy compilations to prevent OOM + compile_normal=16 # Allow more parallel normal compilations + ) -# Available test configurations (minimal set for fast CI/testing) -set(TEST_CONFIGS - "simple_test_config" - # "medium_tiles_config" # Uncomment for broader testing -) - -# Data types for testing (core precision types) -set(TEST_DATATYPES "fp16" "bf16") -# Extended data type options: -# set(TEST_DATATYPES "fp16" "bf16" "fp32" "fp64" "int8") - -# Matrix layouts for testing (row-column-row is most common) -set(TEST_LAYOUTS "rcr") -# Extended layout options: -# set(TEST_LAYOUTS "rcr" "rrr" "ccr" "crr") + # Enable compiler cache if available and explicitly requested + # Disabled by default due to permission issues in CI environments + option(ENABLE_CCACHE_TESTS "Enable ccache for test compilation" OFF) + if(ENABLE_CCACHE_TESTS) + find_program(CCACHE_PROGRAM ccache) + if(CCACHE_PROGRAM) + set(CMAKE_CXX_COMPILER_LAUNCHER ${CCACHE_PROGRAM}) + message(STATUS "Using ccache for faster test compilation") + else() + message(WARNING "ccache requested but not found") + endif() + else() + message(STATUS "ccache disabled for tests (use -DENABLE_CCACHE_TESTS=ON to enable)") + endif() # ============================================================================ -# Test Target Generation Loop +# Test Configuration Matrix - Clean Focused Design # ============================================================================ -foreach(datatype IN LISTS TEST_DATATYPES) - foreach(layout IN LISTS TEST_LAYOUTS) - foreach(config IN LISTS TEST_CONFIGS) - set(CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${config}.json") - if(EXISTS ${CONFIG_FILE}) - message(STATUS "Building tests for ${datatype}_${layout}_${config}") - build_gemm_test_targets("${datatype}" "${layout}" "${config}") - else() - message(WARNING "Config file not found: ${CONFIG_FILE}") - 
endif() +# All supported data types and layouts for comprehensive testing +# Note: fp64 not included (no MFMA hardware support) +set(TEST_DATATYPES "fp16;fp8;bf16;fp32") +set(TEST_LAYOUTS "rcr;rrr;ccr;crr") + +# ============================================================================ +# Test Target Generation - Datatype-Specific Categories +# ============================================================================ + +# 1. SMALL DATATYPES: Test optimized config for small data types (fp8, fp16, bf16) +# These data types can use larger warp tiles due to smaller memory footprint +set(SMALL_DATATYPE_CONFIG "small_datatype_config") +set(SMALL_DATATYPE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${SMALL_DATATYPE_CONFIG}.json") +set(SMALL_DATATYPES "fp8;fp16;bf16") + +if(EXISTS ${SMALL_DATATYPE_CONFIG_FILE}) + message(STATUS "Processing small datatype config: ${SMALL_DATATYPE_CONFIG} (fp8, fp16, bf16)") + foreach(datatype IN LISTS SMALL_DATATYPES) + # fp8, fp16, bf16: testing all layouts (rcr, rrr, ccr, crr) + foreach(layout IN LISTS TEST_LAYOUTS) + build_gemm_test_targets("${datatype}" "${layout}" "${SMALL_DATATYPE_CONFIG}") endforeach() endforeach() -endforeach() +else() + message(WARNING "Small datatype config file not found: ${SMALL_DATATYPE_CONFIG_FILE}") +endif() -message(STATUS "GEMM tile engine tests configured for ${TEST_DATATYPES} with ${TEST_LAYOUTS} layouts using ${TEST_CONFIGS} configurations") +# 2. PADDING COVERAGE: Test padding combinations with fixed fp16/rcr configuration +# This focuses on padding behavior (pad_m, pad_n, pad_k) +set(PADDING_CONFIG "padding_coverage_config") +set(PADDING_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${PADDING_CONFIG}.json") + +if(EXISTS ${PADDING_CONFIG_FILE}) + message(STATUS "Processing padding config: ${PADDING_CONFIG} (fp16/rcr only)") + build_gemm_test_targets("fp16" "rcr" "${PADDING_CONFIG}") +else() + message(WARNING "Padding config file not found: ${PADDING_CONFIG_FILE}") +endif() + +# 3. 
COVERAGE LEVEL: Quick or comprehensive testing +# Quick: ~144 kernels with multiple tile sizes and trait combinations +# Comprehensive: Several thousand kernels with extensive tile sizes, warp configurations, and all trait combinations +set(COVERAGE_LEVEL "quick" CACHE STRING "Coverage level: quick or comprehensive") +set_property(CACHE COVERAGE_LEVEL PROPERTY STRINGS "quick" "comprehensive") + +if(COVERAGE_LEVEL STREQUAL "quick") + set(COVERAGE_CONFIG "quick_coverage_config") + set(COVERAGE_DESC "Quick - approximately 144 kernels with trait combinations") +elseif(COVERAGE_LEVEL STREQUAL "comprehensive") + set(COVERAGE_CONFIG "comprehensive_coverage_config") + set(COVERAGE_DESC "Comprehensive - several thousand kernels with extensive tile and trait coverage") +else() + message(FATAL_ERROR "Invalid COVERAGE_LEVEL: ${COVERAGE_LEVEL}. Must be 'quick' or 'comprehensive'") +endif() + +set(COVERAGE_CONFIG_FILE "${CMAKE_CURRENT_SOURCE_DIR}/configs/${COVERAGE_CONFIG}.json") + +if(EXISTS ${COVERAGE_CONFIG_FILE}) + message(STATUS "Processing coverage config: ${COVERAGE_LEVEL} - ${COVERAGE_DESC}") + build_gemm_test_targets("fp16" "rcr" "${COVERAGE_CONFIG}") +else() + message(WARNING "Coverage config file not found: ${COVERAGE_CONFIG_FILE}") +endif() +# ============================================================================ + + +message(STATUS "GEMM tile engine tests configured with datatype-specific design:") +message(STATUS " - Small datatypes: fp8/fp16/bf16 (all layouts)") +message(STATUS " - Padding coverage with fp16/rcr") +message(STATUS " - Coverage level: ${COVERAGE_LEVEL} (~144 kernels quick, several thousand comprehensive)") +message(STATUS " Use -DCOVERAGE_LEVEL=comprehensive for extensive testing") diff --git a/test/ck_tile/gemm_tile_engine/README.md b/test/ck_tile/gemm_tile_engine/README.md index d99b4115d3..87ce0c9fd0 100644 --- a/test/ck_tile/gemm_tile_engine/README.md +++ b/test/ck_tile/gemm_tile_engine/README.md @@ -17,11 +17,69 @@ JSON Config → 
tile_engine Python scripts → Generated Headers → Test Execut ``` - **`--list_kernels`**: Get available kernel configurations from JSON +- **`--gen_individual`**: Generate all kernel headers in parallel during CMake configuration - **`--gen_single`**: Generate individual kernel header for each configuration - **Same verification**: Uses tile_engine's adaptive error thresholds and reference calculations - **Same patterns**: Follows tile_engine's tensor initialization, stride calculation, and kernel launching +### Config-Specific Test Parameters +Each test configuration can specify optimized problem sizes in its JSON file: +- **`test_params.problem_sizes`**: Array of `{m, n, k, split_k}` configurations +- **CMake extraction**: `extract_test_params.py` generates config-specific test parameter files +- **Build integration**: Each test target uses parameters appropriate for its kernel configuration +- **Optimized testing**: Different configs test different problem sizes that showcase their strengths The key idea: **Unit tests that use tile_engine's exact kernel generation and verification methodology** instead of creating separate test infrastructure. + +## Test Configurations + +### 1. **Simple Test** (`simple_test_config.json`) +- **Purpose**: Basic functionality validation +- **Config**: 128x128x64, warp 2x2x1, warp_tile 16x16x16 +- **Traits**: compv3 + compv4 pipelines +- **Coverage**: ~2 kernels per datatype/layout + +### 2. **Small Datatype** (`small_datatype_config.json`) +- **Purpose**: Optimized for fp8/fp16/bf16 data types +- **Config**: 128x128x32, warp 2x2x1, warp_tile 32x32x16 +- **Traits**: compv3 pipeline only +- **Coverage**: All 4 layouts (rcr, rrr, ccr, crr) for fp8, fp16, bf16 + +### 3. 
**Padding Coverage** (`padding_coverage_config.json`) +- **Purpose**: Test padding behavior with all padding flags enabled +- **Config**: Fixed 64x64x32, warp 2x2x1, warp_tile 32x32x16 +- **Padding**: All enabled (pad_m=true, pad_n=true, pad_k=true) +- **Problem sizes**: Vector-aligned but not tile-aligned (104×104×56, 200×152×80, 152×200×64) +- **Coverage**: 1 kernel configuration testing padding with irregular sizes + +### 4. **Coverage Testing** (Quick or Comprehensive) +- **Purpose**: Comprehensive testing across tile sizes, warp configurations, and trait combinations +- **Quick** (`quick_coverage_config.json`): Approximately 144 kernels + - tile_m/n: [32, 64, 256], tile_k: [16, 32] + - warp config: 2×2×1, warp_tile 16×16×16 + - Traits: 3 pipelines × 2 epilogues × 2 schedulers (persistent=false only) + - Focused set testing trait combinations with multiple tile sizes +- **Comprehensive** (`comprehensive_coverage_config.json`): Several thousand kernels + - tile_m/n: [16-256 step 16] + - tile_k: [16, 32, 64] + - warp_m/n: [1, 2, 4], warp_tile_m/n: [16, 32], warp_tile_k: [8, 16, 32, 64, 128] + - Traits: 3 pipelines × 2 epilogues × 2 schedulers × 2 persistent + - Extensive coverage across all tile sizes, warp configurations, and trait combinations + - Exact count varies based on validation filtering +- **Note**: Use CMake option `-DCOVERAGE_LEVEL=comprehensive` to enable comprehensive testing (default is quick) + +## Data Type Support +- ✅ **fp8, fp16, bf16**: Fully supported - all layouts (rcr, rrr, ccr, crr) +- ❌ **fp64**: Not supported (hardware MFMA limitation) +- ⏳ **fp32, bf8, pk_int4_t**: Not yet supported by gemm_instance_builder (will be added later) + +## Test Result Behavior + +Tests automatically handle unsupported configurations through runtime validation: +- **PASSED**: Kernel executed correctly with results within error thresholds ✅ +- **SKIPPED**: Kernel validation returned "Arguments not supported" (expected for certain problem sizes/configurations) ⚠️ +- 
**FAILED**: Actual error or incorrect computation results ❌ + +When a kernel's `IsSupportedArgument()` check fails (e.g., due to vector alignment requirements, dimension constraints, or padding limitations), the test is automatically skipped rather than failed. This allows comprehensive testing across various problem sizes while gracefully handling configurations that don't meet specific kernel requirements. diff --git a/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json b/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json new file mode 100644 index 0000000000..f2524e4a61 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/comprehensive_coverage_config.json @@ -0,0 +1,37 @@ +{ + "problem": { + "description": "Comprehensive coverage testing - extensive tile size coverage (16-256, step 16) with multiple warp configurations and all trait combinations. Several thousand kernels." + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 256, "split_k": 1}, + {"m": 1024, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 2}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 4} + ] + }, + "tile_config": { + "tile_m": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]}, + "tile_n": {"values": [16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240, 256]}, + "tile_k": {"values": [16, 32, 64]}, + "warp_m": {"values": [1, 2, 4]}, + "warp_n": {"values": [1, 2, 4]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16, 32]}, + "warp_tile_n": {"values": [16, 32]}, + "warp_tile_k": {"values": [8, 16, 32, 64, 128]} + }, + "trait_config": { + "pipeline": {"values": ["mem", "compv3", "compv4"]}, + "epilogue": {"values": ["default", "cshuffle"]}, + "scheduler": {"values": ["intrawave", "interwave"]}, + "pad_m": {"values": [false]}, + "pad_n": 
{"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [true, false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json b/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json new file mode 100644 index 0000000000..e9fcb6fb80 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/large_datatype_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Configuration optimized for large data types (fp32) with smaller warp tiles due to memory constraints" + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 128, "split_k": 1}, + {"m": 512, "n": 256, "k": 192, "split_k": 1}, + {"m": 256, "n": 384, "k": 192, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [256]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json b/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json new file mode 100644 index 0000000000..33bada839d --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/padding_coverage_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Padding coverage testing - fixed fp16/rcr config with all padding flags enabled (pad_m, pad_n, pad_k), exercised with problem sizes that are not tile-aligned" + }, + "test_params": { + "problem_sizes": [ + {"m": 104, "n": 104, "k": 56, "split_k": 1}, + {"m": 200, "n": 152, "k": 80, "split_k": 1}, + {"m": 152, "n": 
200, "k": 64, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [64]}, + "tile_n": {"values": [64]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [true]}, + "pad_n": {"values": [true]}, + "pad_k": {"values": [true]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json b/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json new file mode 100644 index 0000000000..dcc6e99aee --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/quick_coverage_config.json @@ -0,0 +1,34 @@ +{ + "problem": { + "description": "Quick coverage testing - tests multiple tile sizes with all trait combinations (pipelines, epilogues, schedulers). Approximately 144 kernels." 
+ }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 2}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 4} + ] + }, + "tile_config": { + "tile_m": {"values": [32, 64, 256]}, + "tile_n": {"values": [32, 64, 256]}, + "tile_k": {"values": [16, 32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["mem", "compv3", "compv4"]}, + "epilogue": {"values": ["default", "cshuffle"]}, + "scheduler": {"values": ["intrawave", "interwave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json b/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json index a4f32a1907..498ef9fa33 100644 --- a/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json +++ b/test/ck_tile/gemm_tile_engine/configs/simple_test_config.json @@ -1,88 +1,33 @@ { + "problem": { + "description": "Basic functionality validation with moderate problem sizes" + }, + "test_params": { + "problem_sizes": [ + {"m": 256, "n": 256, "k": 128, "split_k": 1}, + {"m": 512, "n": 256, "k": 256, "split_k": 1}, + {"m": 256, "n": 512, "k": 256, "split_k": 1} + ] + }, "tile_config": { - "tile_m": { - "values": [ - 128 - ] - }, - "tile_n": { - "values": [ - 128 - ] - }, - "tile_k": { - "values": [ - 64 - ] - }, - "warp_m": { - "values": [ - 2 - ] - }, - "warp_n": { - "values": [ - 2 - ] - }, - "warp_k": { - "values": [ - 1 - ] - }, - "warp_tile_m": { - "values": [ - 16 - ] - }, - "warp_tile_n": { - "values": [ - 16 - ] - }, - "warp_tile_k": { - "values": [ - 16 - ] - } + "tile_m": {"values": [128]}, + "tile_n": {"values": [128]}, + "tile_k": 
{"values": [64]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [16]}, + "warp_tile_n": {"values": [16]}, + "warp_tile_k": {"values": [16]} }, "trait_config": { - "pipeline": { - "values": [ - "compv3", - "compv4" - ] - }, - "scheduler": { - "values": [ - "intrawave" - ] - }, - "epilogue": { - "values": [ - "default" - ] - }, - "pad_m": { - "values": [ - false - ] - }, - "pad_n": { - "values": [ - false - ] - }, - "pad_k": { - "values": [ - false - ] - }, - "persistent": { - "values": [ - false - ] - } + "pipeline": {"values": ["compv3", "compv4"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + "pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} }, "k_block_per_cu": 1, "permute_n": false diff --git a/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json b/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json new file mode 100644 index 0000000000..d0d9f99a0c --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/configs/small_datatype_config.json @@ -0,0 +1,35 @@ +{ + "problem": { + "description": "Configuration optimized for small data types (fp8, fp16, bf16) with larger warp tiles" + }, + "test_params": { + "problem_sizes": [ + {"m": 512, "n": 512, "k": 256, "split_k": 1}, + {"m": 1024, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 1024, "k": 512, "split_k": 1}, + {"m": 1024, "n": 1024, "k": 256, "split_k": 1} + ] + }, + "tile_config": { + "tile_m": {"values": [128]}, + "tile_n": {"values": [128]}, + "tile_k": {"values": [32]}, + "warp_m": {"values": [2]}, + "warp_n": {"values": [2]}, + "warp_k": {"values": [1]}, + "warp_tile_m": {"values": [32]}, + "warp_tile_n": {"values": [32]}, + "warp_tile_k": {"values": [16]} + }, + "trait_config": { + "pipeline": {"values": ["compv3"]}, + "epilogue": {"values": ["default"]}, + "scheduler": {"values": ["intrawave"]}, + 
"pad_m": {"values": [false]}, + "pad_n": {"values": [false]}, + "pad_k": {"values": [false]}, + "persistent": {"values": [false]} + }, + "k_block_per_cu": 1, + "permute_n": false +} diff --git a/test/ck_tile/gemm_tile_engine/extract_test_params.py b/test/ck_tile/gemm_tile_engine/extract_test_params.py new file mode 100644 index 0000000000..c82591e391 --- /dev/null +++ b/test/ck_tile/gemm_tile_engine/extract_test_params.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 + +import json +import argparse +import os +from pathlib import Path + + +def extract_test_params(config_file, output_file): + """Extract test parameters from config JSON and write to output file""" + + # Read config file + with open(config_file, "r") as f: + config = json.load(f) + + # Extract test parameters + test_params = [] + if "test_params" in config and "problem_sizes" in config["test_params"]: + test_params = config["test_params"]["problem_sizes"] + else: + # Default test parameters if none specified + test_params = [ + {"m": 256, "n": 256, "k": 128, "split_k": 1}, + {"m": 256, "n": 256, "k": 1024, "split_k": 1}, + {"m": 256, "n": 512, "k": 512, "split_k": 1}, + {"m": 512, "n": 256, "k": 512, "split_k": 1}, + ] + + # Write to output file in C++ format + output_dir = Path(output_file).parent + output_dir.mkdir(parents=True, exist_ok=True) + + with open(output_file, "w") as f: + f.write("// Generated test parameters for this configuration\n") + f.write("// This file is auto-generated during CMake configuration\n\n") + f.write("static const std::vector CONFIG_TEST_PARAMS = {\n") + + for i, params in enumerate(test_params): + comma = "," if i < len(test_params) - 1 else "" + f.write( + f" {{{params['m']}, {params['n']}, {params['k']}, {params['split_k']}}}{comma}\n" + ) + + f.write("};\n") + + print( + f"Extracted {len(test_params)} test parameters from {config_file} -> {output_file}" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Extract test parameters from config JSON" + ) + 
parser.add_argument("--config_file", required=True, help="Input config JSON file") + parser.add_argument( + "--output_file", required=True, help="Output test parameters file" + ) + + args = parser.parse_args() + + if not os.path.exists(args.config_file): + print(f"Error: Config file not found: {args.config_file}") + return 1 + + extract_test_params(args.config_file, args.output_file) + return 0 + + +if __name__ == "__main__": + exit(main()) diff --git a/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp b/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp index 439dd4f39b..2054136647 100644 --- a/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp +++ b/test/ck_tile/gemm_tile_engine/test_gemm_simple.cpp @@ -1,8 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -// Unit tests for tile_engine generated GEMM kernels -// Tests kernel correctness using tile_engine's verification methodology +/** + * @file test_gemm_simple.cpp + * @brief Unit tests for GEMM kernels generated by gemm_instance_builder + * + * This test includes kernels generated during CMake configuration by + * gemm_instance_builder.py and tests them with problem sizes extracted + * from the corresponding JSON configuration files. 
+ */ #include #include @@ -68,6 +74,11 @@ struct GemmTestParams int m, n, k, split_k; }; +// Include config-specific test parameters (after GemmTestParams struct is defined) +#ifdef GEMM_TEST_PARAMS_HPP +#include GEMM_TEST_PARAMS_HPP +#endif + class GemmTileEngineTest : public ::testing::TestWithParam { protected: @@ -185,7 +196,16 @@ TEST_P(GemmTileEngineTest, BasicFunctionality) } catch(const std::exception& e) { - FAIL() << "Kernel launch failed: " << e.what(); + std::string error_msg(e.what()); + // If arguments not supported, skip the test (configuration validation failure, not a bug) + if(error_msg.find("Arguments not supported") != std::string::npos) + { + GTEST_SKIP() << "Configuration not supported: " << e.what(); + } + else + { + FAIL() << "Kernel launch failed: " << e.what(); + } } // Copy result back from device @@ -208,13 +228,11 @@ TEST_P(GemmTileEngineTest, KernelInfo) << std::endl; } -// Define test parameters for GEMM verification +// Use config-specific test parameters (included via compile flags) +// CONFIG_TEST_PARAMS is defined in the auto-generated test_params.hpp file INSTANTIATE_TEST_SUITE_P(GemmVerification, GemmTileEngineTest, - ::testing::Values(GemmTestParams{256, 256, 128, 1}, - GemmTestParams{256, 256, 1024, 1}, - GemmTestParams{256, 512, 512, 1}, - GemmTestParams{512, 256, 512, 1}), + ::testing::ValuesIn(CONFIG_TEST_PARAMS), [](const ::testing::TestParamInfo& param_info) { return std::to_string(param_info.param.m) + "x" + std::to_string(param_info.param.n) + "x" + diff --git a/test/gemm_blockscale_wp/CMakeLists.txt b/test/gemm_blockscale_wp/CMakeLists.txt new file mode 100644 index 0000000000..d198db0870 --- /dev/null +++ b/test/gemm_blockscale_wp/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility 
device_gemm_blockscale_wp_instance) + endif() +endif() diff --git a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp new file mode 100644 index 0000000000..5d88e04690 --- /dev/null +++ b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmBlockScaleWP_FP8_MK_NK : public ck::test::TestGemmBlockscaleWPCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + std::tuple< F8, F32, F8, F32, F8, BF16> +#endif + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmBlockScaleWP_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmBlockScaleWP_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 1024; + constexpr int K = 4096; + + for(int M : Ms) + this->Run(M, N, K); +} diff --git a/test/gemm_blockscale_wp/test_gemm_common.hpp b/test/gemm_blockscale_wp/test_gemm_common.hpp new file mode 100644 index 0000000000..25ed67a737 --- /dev/null +++ b/test/gemm_blockscale_wp/test_gemm_common.hpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_blockscale_wp_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmBlockscaleWPCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using A0DataType = std::tuple_element_t<2, Tuple>; + using A1DataType = std::tuple_element_t<3, Tuple>; + using B0DataType = std::tuple_element_t<4, Tuple>; + using B1DataType = std::tuple_element_t<5, Tuple>; + using ComputeDataType = std::tuple_element_t<6, Tuple>; + using CDataType = std::tuple_element_t<7, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + static constexpr index_t ScaleBlockM = 1; + static constexpr index_t ScaleBlockN = 128; + static constexpr index_t ScaleBlockK = 128; + + void Run(const int M, const int N, const int K, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v ? K : M; + int StrideB = std::is_same_v ? N : K; + int StrideC = std::is_same_v ? 
N : M; + + all_success = + all_success & + ck::profiler::profile_gemm_blockscale_weightpreshuffle_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_multiply_multiply_wp/CMakeLists.txt b/test/gemm_multiply_multiply_wp/CMakeLists.txt new file mode 100644 index 0000000000..4302084a6f --- /dev/null +++ b/test/gemm_multiply_multiply_wp/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_multiply_multiply_wp_xdl_fp8 test_gemm_multiply_multiply_wp_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_multiply_multiply_wp_xdl_fp8 PRIVATE utility device_gemm_multiply_multiply_wp_instance) + endif() +endif() diff --git a/test/gemm_multiply_multiply_wp/test_gemm_common.hpp b/test/gemm_multiply_multiply_wp/test_gemm_common.hpp new file mode 100644 index 0000000000..37e2b353e6 --- /dev/null +++ b/test/gemm_multiply_multiply_wp/test_gemm_common.hpp @@ -0,0 +1,93 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_multiply_multiply_wp_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmMultiplyMultiplyWPCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using D0Layout = std::tuple_element_t<2, Tuple>; + using D1Layout = std::tuple_element_t<3, Tuple>; + using ELayout = Row; + using ADataType = std::tuple_element_t<4, Tuple>; + using BDataType = std::tuple_element_t<5, Tuple>; + using ComputeDataType = std::tuple_element_t<6, Tuple>; + using D0DataType = std::tuple_element_t<7, Tuple>; + using D1DataType = std::tuple_element_t<8, Tuple>; + using EDataType = std::tuple_element_t<9, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; // decimal value initialization + static constexpr bool log_ = false; + static constexpr bool bench_ = false; // measure kernel performance + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2, 4}; } + + void Run(const int M, const int N, const int K) + { + for(size_t i = 0; i < k_batches_.size(); i++) + { + RunSingle(M, N, K, k_batches_[i]); + } + } + + void RunSingle( + const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v, Row> ? K : M; + int StrideB = std::is_same_v, Row> ? N : K; + int StrideD0 = std::is_same_v, Row> ? N : M; + int StrideD1 = std::is_same_v, Row> ? N : M; + int StrideE = std::is_same_v ? 
N : M; + + all_success = + all_success & + ck::profiler::profile_gemm_multiply_multiply_weight_preshuffle_impl( + verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideD0, + StrideD1, + StrideE, + kbatch, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp new file mode 100644 index 0000000000..bf9b909628 --- /dev/null +++ b/test/gemm_multiply_multiply_wp/test_gemm_multiply_multiply_wp_xdl_fp8.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmMultiplyMultiplyWP_FP8_MK_NK + : public ck::test::TestGemmMultiplyMultiplyWPCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + std::tuple< F8, F8, F8, F32, F32, F16>, + std::tuple< F8, F8, F8, F32, F32, BF16> +#endif + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmMultiplyMultiplyWP_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 224, 256, 
448, 512}; + constexpr int N = 1024; + constexpr int K = 4096; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmMultiplyMultiplyWP_FP8_MK_NK, Regular2) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 448; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} diff --git a/test/gemm_universal_preshuffle/CMakeLists.txt b/test/gemm_universal_preshuffle/CMakeLists.txt new file mode 100644 index 0000000000..0d8955f6a4 --- /dev/null +++ b/test/gemm_universal_preshuffle/CMakeLists.txt @@ -0,0 +1,6 @@ +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_universal_preshuffle_xdl_fp8 test_gemm_universal_preshuffle_xdl_fp8.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_universal_preshuffle_xdl_fp8 PRIVATE utility device_gemm_universal_preshuffle_instance) + endif() +endif() diff --git a/test/gemm_universal_preshuffle/test_gemm_common.hpp b/test/gemm_universal_preshuffle/test_gemm_common.hpp new file mode 100644 index 0000000000..367c1a9c7e --- /dev/null +++ b/test/gemm_universal_preshuffle/test_gemm_common.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "gtest/gtest.h" +#include "ck/ck.hpp" +#include "profiler/profile_gemm_universal_preshuffle_impl.hpp" + +namespace ck { +namespace test { + +using Row = ck::tensor_layout::gemm::RowMajor; +using F32 = float; + +template +class TestGemmUniversalPreshuffleCommon : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = Row; + using ADataType = std::tuple_element_t<2, Tuple>; + using BDataType = std::tuple_element_t<3, Tuple>; + using ComputeDataType = std::tuple_element_t<4, Tuple>; + using CDataType = std::tuple_element_t<5, Tuple>; + + public: + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + std::vector k_batches_; + + void SetUp() override { k_batches_ = {1, 2, 4}; } + + void Run(const int M, const int N, const int K) + { + for(size_t i = 0; i < k_batches_.size(); i++) + { + RunSingle(M, N, K, k_batches_[i]); + } + } + + void RunSingle( + const int M, const int N, const int K, int kbatch = 1, int n_warmup = 1, int n_iter = 10) + { + bool all_success = true; + + int StrideA = std::is_same_v ? K : M; + int StrideB = std::is_same_v ? N : K; + int StrideC = std::is_same_v ? 
N : M; + + all_success = all_success & + ck::profiler::profile_gemm_universal_preshuffle_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideC, + kbatch, + n_warmup, + n_iter); + + EXPECT_TRUE(all_success); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp new file mode 100644 index 0000000000..06dca026ee --- /dev/null +++ b/test/gemm_universal_preshuffle/test_gemm_universal_preshuffle_xdl_fp8.cpp @@ -0,0 +1,77 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_common.hpp" + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template +struct tuple_concat; + +template +struct tuple_concat, std::tuple> +{ + using type = std::tuple; +}; + +} // namespace + +template +class TestGemmUniversalPreshuffle_FP8_MK_NK + : public ck::test::TestGemmUniversalPreshuffleCommon< + typename tuple_concat, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes_MK_NK = ::testing::Types< +#if defined(CK_ENABLE_FP8) + std::tuple< F8, F8, F8, F16>, + std::tuple< F8, F8, F8, BF16> +#endif + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmUniversalPreshuffle_FP8_MK_NK, KernelTypes_MK_NK); + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular0) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 512; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular1) +{ + std::vector Ms{128, 224, 256, 448, 512}; + constexpr int N = 1024; + constexpr int K = 
4096; + + for(int M : Ms) + this->Run(M, N, K); +} + +TYPED_TEST(TestGemmUniversalPreshuffle_FP8_MK_NK, Regular2) +{ + std::vector Ms{128, 256, 512}; + constexpr int N = 448; + constexpr int K = 2048; + + for(int M : Ms) + this->Run(M, N, K); +}