diff --git a/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt b/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt index be83be6ee5..02863398fe 100644 --- a/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt +++ b/example/ck_tile/99_toy_example/02_gemm/CMakeLists.txt @@ -13,6 +13,11 @@ set(EXAMPLE_REDUCE_COMPILE_OPTIONS) # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations list(APPEND EXAMPLE_REDUCE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +if(DEFINED kernel) + message(STATUS "Compiling with Kernel: ${kernel}") + target_compile_definitions(${EXAMPLE_REDUCE} PRIVATE KERNEL_${kernel}=1) +endif() + target_compile_options(${EXAMPLE_REDUCE} PRIVATE ${EXAMPLE_REDUCE_COMPILE_OPTIONS}) # TODO: we have to turn off this global prop, otherwise the progress bar generated diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp index 3894b6cd27..755e2daf85 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_asmem_bsmem_creg_default_policy.hpp @@ -6,10 +6,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp" -#define mfma_m32_n32_k8 0 -#define mfma_m32_n32_k16 0 -#define mfma_m16_n16_k16 0 -#define mfma_m16_n16_k32 1 +#include "config.h" namespace ck_tile { @@ -20,8 +17,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { -#if mfma_m32_n32_k8 -#pragma message ("mfma m32 n32 k8") +#if defined(NAIVE_IMPLEMENTATION) +#pragma message ("mfma m32 n32 k8") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) @@ -34,8 +31,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy { return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1); } 
-#elif mfma_m32_n32_k16 -#pragma message ("mfma m32 n32 k16") +#elif defined(USING_MFMA_32x32x_8x2) +#pragma message ("mfma m32 n32 k16") if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) @@ -48,7 +45,8 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy { return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, 4, 1); } -#elif mfma_m16_n16_k16 + +#elif defined(USING_MFMA_16x16x16) #pragma message("mfma m16 n16 k16") if constexpr(std::is_same_v && std::is_same_v && @@ -62,7 +60,7 @@ struct BlockGemmASmemBSmemCRegDefaultPolicy { return make_tuple(WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, 4, 1); } -#elif mfma_m16_n16_k32 +#elif defined(USING_MFMA_16x16x_16x2) #pragma message("mfma m16 n16 k32") if constexpr(std::is_same_v && std::is_same_v && diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp index 67a7985b48..1a729d8e58 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -114,7 +114,7 @@ struct BlockGemmPipelineAGmemBGmemCReg // Acc register tile auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; -#if 1 +#if defined(ENABLE_PREFETCH) #pragma message ("prefetch") // prefetch // global read 0 diff --git a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp index 6fdb65e26b..b8faa00d85 100644 --- a/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg_default_policy.hpp @@ -7,10 +7,7 @@ #include "ck_tile/core/tensor/tile_distribution.hpp" #include "block_gemm_asmem_bsmem_creg.hpp" 
-#define BANK_CONFLICT_K_FIRST 0 -#define PADDING_K_FIRST 0 -#define PADDING_MN_FIRST 0 -#define XOR 1 +#include "config.h" namespace ck_tile { @@ -26,7 +23,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPack = 8; -#if BANK_CONFLICT_K_FIRST +#if defined(NAIVE_IMPLEMENTATION) #pragma message ("BANK_CONFLICT: K_FIRST") constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -41,7 +38,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); -#elif PADDING_K_FIRST +#elif defined(PADDING_K_FIRST) #pragma message ("BANK_CONFLICT: PADDING_K_FIRST") constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -56,7 +53,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); -#elif PADDING_MN_FIRST +#elif defined(PADDING_MN_FIRST) #pragma message ("BANK_CONFLICT: PADDING_MN_FIRST") constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -71,7 +68,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); -#elif XOR +#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE) #pragma message ("BANK_CONFLICT: XOR") using ADataType = remove_cvref_t; @@ -125,7 +122,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPack = 8; -#if BANK_CONFLICT_K_FIRST +#if defined(NAIVE_IMPLEMENTATION) #pragma message ("BANK_CONFLICT: K_FIRST") constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -140,7 +137,7 @@ struct 
BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); -#elif PADDING_K_FIRST +#elif defined(PADDING_K_FIRST) #pragma message ("BANK_CONFLICT: PADDING_K_FIRST") constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -155,7 +152,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); -#elif PADDING_MN_FIRST +#elif defined(PADDING_MN_FIRST) #pragma message ("BANK_CONFLICT: PADDING_MN_FIRST") constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( make_tuple(number{}, number{}, number{}), @@ -170,7 +167,7 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); -#elif XOR +#elif defined(USING_XOR_BASED_BANK_CONFLICT_FREE) #pragma message ("BANK_CONFLICT: XOR") using BDataType = remove_cvref_t; diff --git a/example/ck_tile/99_toy_example/02_gemm/config.h b/example/ck_tile/99_toy_example/02_gemm/config.h new file mode 100644 index 0000000000..dfff4a6e12 --- /dev/null +++ b/example/ck_tile/99_toy_example/02_gemm/config.h @@ -0,0 +1,31 @@ + +#if defined(KERNEL_A) + #define USING_MFMA_32x32x_8x2 +#elif defined(KERNEL_B) + #define USING_MFMA_16x16x16 +#elif defined(KERNEL_C) + #define USING_MFMA_16x16x_16x2 +#elif defined(KERNEL_D) + #define USING_MFMA_16x16x_16x2 + #define USING_XOR_BASED_BANK_CONFLICT_FREE +#elif defined(KERNEL_E) + #define USING_MFMA_16x16x_16x2 + #define USING_XOR_BASED_BANK_CONFLICT_FREE + #define ADJUST_BLOCK_TILE_SHAPE +#elif defined(KERNEL_F) + #define USING_MFMA_16x16x_16x2 + #define USING_XOR_BASED_BANK_CONFLICT_FREE + #define ADJUST_BLOCK_TILE_SHAPE + #define ENABLE_PREFETCH + #define ENABLE_INSTRUCTION_SCH +#elif defined(KERNEL_G) + #define USING_MFMA_16x16x_16x2 + #define USING_XOR_BASED_BANK_CONFLICT_FREE + #define 
ADJUST_BLOCK_TILE_SHAPE + #define ENABLE_PREFETCH + #define ENABLE_INSTRUCTION_SCH + #define ENABLE_CACHE_AWARE_WG_SCH +#else + #define NAIVE_IMPLEMENTATION +#endif + diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp index 320b94c481..6fb5c58ed8 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.cpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.cpp @@ -3,6 +3,7 @@ #include "ck_tile/host.hpp" #include "reference_gemm.hpp" +#include "config.h" #include "gemm.hpp" /* @@ -47,6 +48,43 @@ int main(int argc, char* argv[]) K = std::stoi(argv[4]); } +#if defined(KERNEL_A) + printf("*** KernelA test *** \n"); + printf(" --> Using mfma_32x32x(8x2)\n"); +#elif defined(KERNEL_B) + printf("*** KernelB test *** \n"); + printf(" --> Using mfma_16x16x16\n"); +#elif defined(KERNEL_C) + printf("*** KernelC test *** \n"); + printf(" --> Using mfma_16x16x(16x2)\n"); +#elif defined(KERNEL_D) + printf("*** KernelD test *** \n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank conflict-free\n"); +#elif defined(KERNEL_E) + printf("*** KernelE test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank conflict-free\n"); + printf(" --> Adjust block tile shape\n"); +#elif defined(KERNEL_F) + printf("*** KernelF test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank conflict-free\n"); + printf(" --> Adjust block tile shape\n"); + printf(" --> Enable prefetch\n"); + printf(" --> Enable instruction schedule\n"); +#elif defined(KERNEL_G) + printf("*** KernelG test ***\n"); + printf(" --> Using mfma_16x16x(16x2)\n"); + printf(" --> XOR-based bank conflict-free\n"); + printf(" --> Adjust block tile shape\n"); + printf(" --> Enable prefetch\n"); + printf(" --> Enable instruction schedule\n"); + printf(" --> Enable cache-aware thread blocks schedule\n"); +#else + printf("*** Naive implementation test ***\n"); +#endif + const 
ck_tile::index_t Lda = K; const ck_tile::index_t Ldb = K; const ck_tile::index_t Ldc = N; @@ -82,7 +120,7 @@ int main(int argc, char* argv[]) constexpr ck_tile::index_t kBlockSize = 256; -#if 1 +#ifdef ADJUST_BLOCK_TILE_SHAPE #pragma message ("(Increase KperBlock, reduce MperBlock) -> increase Grid size") constexpr ck_tile::index_t kGemmMPerBlock = 128; constexpr ck_tile::index_t kGemmKPerBlock = 64; diff --git a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp index a8002701ab..c6d2caf59c 100644 --- a/example/ck_tile/99_toy_example/02_gemm/gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/gemm.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/tensor/tile_distribution.hpp" #include "block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "config.h" #include "grid_gemm.hpp" namespace ck_tile { @@ -28,7 +29,7 @@ struct GridGemmProblem using CElementFunction = CElementFunction_; }; -#ifndef INSTRUCTION_SCHEDULE +#ifndef ENABLE_INSTRUCTION_SCH template struct TileGemmShape { @@ -85,7 +86,7 @@ struct Gemm CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) { -#if 1 +#if defined(ENABLE_CACHE_AWARE_WG_SCH) #pragma message ("Cache-aware work group sch") return [=](index_t block_1d_id) { constexpr index_t M01 = 4; @@ -147,7 +148,7 @@ struct Gemm #endif } -#ifndef INSTRUCTION_SCHEDULE +#ifndef ENABLE_INSTRUCTION_SCH template CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() { diff --git a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp index 8b6b845577..b66a66626e 100644 --- a/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp +++ b/example/ck_tile/99_toy_example/02_gemm/grid_gemm.hpp @@ -1,9 +1,8 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
-#define INSTRUCTION_SCHEDULE -#ifdef INSTRUCTION_SCHEDULE +#ifdef ENABLE_INSTRUCTION_SCH #include "instruction_schedule/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "instruction_schedule/gemm_pipeline_problem.hpp" #include "instruction_schedule/gemm_universal_pipeline_ag_bg_cr_policy.hpp" @@ -59,7 +58,7 @@ struct GridGemm auto b_block_window = make_tile_window( b_grid, make_tuple(number{}, number{}), {iN, 0}); -#ifndef INSTRUCTION_SCHEDULE +#ifndef ENABLE_INSTRUCTION_SCH #pragma message ("disable instruction scheduling") // Block GEMM pipeline w/o instruction scheduling constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline();