diff --git a/CMakeLists.txt b/CMakeLists.txt index acae1f5ece..eaed7d3509 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -766,6 +766,9 @@ if(CK_EXPERIMENTAL_BUILDER) ${PROJECT_SOURCE_DIR}/experimental/builder/include/ck_tile/builder DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile ) + + set(CK_TILE_SRC_FOLDER ${CMAKE_SOURCE_DIR}/include/ck_tile/) + rocm_install(DIRECTORY ${CK_TILE_SRC_FOLDER} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/ck_tile) endif() set(CPACK_RESOURCE_FILE_LICENSE "${CMAKE_CURRENT_SOURCE_DIR}/LICENSE") diff --git a/example/26_contraction/run_contraction_bilinear_example.inc b/example/26_contraction/run_contraction_bilinear_example.inc index 78135d6296..69eb42defd 100644 --- a/example/26_contraction/run_contraction_bilinear_example.inc +++ b/example/26_contraction/run_contraction_bilinear_example.inc @@ -233,7 +233,20 @@ int run_contraction_bilinear_example(int argc, char* argv[]) } } - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + if(ck::is_gfx11_supported()) + { + return ck::utils::check_err(e_ms_ns_device_result, + e_ms_ns_host_result, + "Error: Incorrect results!", + 1e-4, + 1e-4) + ? 0 + : 1; + } + else + { + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } } return 0; diff --git a/example/26_contraction/run_contraction_scale_example.inc b/example/26_contraction/run_contraction_scale_example.inc index 67f29dbc36..a7451fab71 100644 --- a/example/26_contraction/run_contraction_scale_example.inc +++ b/example/26_contraction/run_contraction_scale_example.inc @@ -216,7 +216,20 @@ int run_contraction_scale_example(int argc, char* argv[]) } } - return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + if(ck::is_gfx11_supported()) + { + return ck::utils::check_err(e_ms_ns_device_result, + e_ms_ns_host_result, + "Error: Incorrect results!", + 1e-4, + 1e-4) + ? 0 + : 1; + } + else + { + return ck::utils::check_err(e_ms_ns_device_result, e_ms_ns_host_result) ? 0 : 1; + } } return 0; diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index 47c47334e7..f79494a478 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -12,40 +12,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - struct GemmConfigBase { static constexpr bool kPadM = false; @@ -122,7 +88,8 @@ struct GemmConfigComputeV3 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -141,7 +108,8 @@ struct GemmConfigComputeV3_1 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -160,7 +128,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -204,7 +173,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -223,7 +193,8 @@ struct GemmConfigComputeV4_1 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -242,7 +213,8 @@ struct GemmConfigComputeV5 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V5; @@ -282,7 +254,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; @@ -306,7 +279,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index c5a400b4dd..67b411c1f0 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -11,40 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - template struct GemmTypeConfig; @@ -111,7 +77,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = false; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V3; @@ -134,7 +101,8 @@ struct GemmConfigComputeV4 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -157,7 +125,8 @@ struct GemmConfigComputeV4_V2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; static constexpr ck_tile::GemmPipeline Pipeline = ck_tile::GemmPipeline::COMPUTE_V4; @@ -178,7 +147,8 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool kPadK = true; @@ -203,7 +173,8 @@ struct GemmConfigPreshufflePrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr int kBlockPerCu = 2; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp index 30a25d83d7..2724834bb5 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.hpp @@ -11,24 +11,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/utility/json_dump.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - struct GemmConfigBase { static constexpr bool kPadM = false; diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp index 0317685770..1fa8a03087 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp @@ -10,40 +10,6 @@ #include "ck_tile/ops/gemm.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - template struct GemmTypeConfig; @@ -100,7 +66,8 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -117,7 +84,7 @@ struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 7a4760e1da..37fc998e5b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -24,39 +24,6 @@ inline size_t hash_multiple_strings(const std::vector& inputs) return combined_hash; } -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} -template -constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -} - template static constexpr inline auto is_row_major(Layout layout_) { @@ -124,7 +91,8 @@ struct GemmConfigQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -140,7 +108,8 @@ struct GemmConfigRowColQuant : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template @@ -157,7 +126,7 @@ struct GemmConfigPreshuffleQuantDecode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleQuant = true; }; @@ -176,7 +145,7 @@ struct GemmConfigPreshuffleB_BQuant_Decode : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; @@ -206,7 +175,7 @@ struct GemmConfigPreshuffleB_BQuant_Prefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); + ck_tile::get_k_warp_tile(); static constexpr bool PreshuffleB = true; static constexpr bool DoubleSmemBuffer = true; @@ -236,7 +205,8 @@ struct GemmConfigQuantPrefill : public GemmConfigBase static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + static constexpr ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); }; template diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp index 9a9c2235e0..99e7479e36 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_dispatcher.hpp @@ -84,63 +84,46 @@ namespace ck_tile::builder::factory { // CK Tile kernel template -consteval bool IsTileAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && SpecifiesTileTransfer && - SpecifiesTileConvSpecialization && SpecifiesTileBlockGemm && - SpecifiesTileOptimizations; -} +concept IsTileAlgorithm = ConvAlgorithmDescriptor && SpecifiesTileThreadBlock && + SpecifiesTileTransfer && SpecifiesTileConvSpecialization && + SpecifiesTileBlockGemm && SpecifiesTileOptimizations; // XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline) template -consteval bool IsXdlV3Algorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesBlockGemm; -} +concept IsXdlV3Algorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesBlockGemm; // Standard XDL-based kernel (uses XDLops hardware instructions for matrix multiply) template -consteval bool IsXdlAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesNumGroupsToMerge && - SpecifiesLoopScheduler; -} +concept IsXdlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseXdlGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && + SpecifiesNumGroupsToMerge && SpecifiesLoopScheduler; // WMMA-based kernel (uses Wavefront Matrix-Matrix Accumulate instructions) template -consteval bool IsWmmaAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && - SpecifiesBlockTransfer && SpecifiesLdsTransfer && - SpecifiesThreadClusterAccessOrder && SpecifiesSourceAccessOrder && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; -} +concept IsWmmaAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesGridwiseWmmaGemm && + SpecifiesBlockTransfer && SpecifiesLdsTransfer && SpecifiesThreadClusterAccessOrder && + SpecifiesSourceAccessOrder && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesNumPrefetchStages && SpecifiesLoopScheduler; // Specialized DL kernel for specific NHWC/KYXC/NHWK data layouts template -consteval bool IsDlAlgorithm() -{ - return ConvAlgorithmDescriptor && SpecifiesThreadBlock && - SpecifiesFwdConvSpecialization && SpecifiesGemmSpecialization && - SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && - SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; -} +concept IsDlAlgorithm = + ConvAlgorithmDescriptor && SpecifiesThreadBlock && SpecifiesFwdConvSpecialization && + SpecifiesGemmSpecialization && SpecifiesDlThreadConfig && SpecifiesDlThreadCluster && + SpecifiesDlBlockTransfer && SpecifiesDlEpilogue; // XDL-based kernel with large tensor support template -consteval bool IsLargeTensorAlgorithm() -{ - return IsXdlAlgorithm() && SpecifiesLargeTensorSupport; -} +concept IsLargeTensorAlgorithm = + IsXdlAlgorithm && SpecifiesLargeTensorSupport; template ; // CK Tile supports common factory for each direction - if constexpr(IsTileAlgorithm()) + if constexpr(IsTileAlgorithm) { return typename ConvTileFactory::Instance{}; } else if constexpr(ConvDirectionIsForward) { - if constexpr(IsXdlV3Algorithm()) + if constexpr(IsXdlV3Algorithm) { return typename ConvFwdXdlV3Factory::Instance{}; } - else if constexpr(IsXdlAlgorithm()) + else if constexpr(IsXdlAlgorithm) { return typename ConvFwdXdlFactory::Instance{}; } - else if constexpr(IsWmmaAlgorithm()) + else if constexpr(IsWmmaAlgorithm) { return typename ConvFwdWmmaFactory::Instance{}; } - else if constexpr(IsDlAlgorithm()) + else if constexpr(IsDlAlgorithm) { return typename ConvFwdDlFactory::Instance{}; } - else if constexpr(IsLargeTensorAlgorithm()) + else if constexpr(IsLargeTensorAlgorithm) { return typename ConvFwdLargeTensorFactory::Instance{}; } diff --git a/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp b/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp index 14b8e75668..8db0e5d25d 100644 --- a/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp +++ b/experimental/builder/include/ck_tile/builder/testing/type_traits.hpp @@ -31,13 +31,20 @@ constexpr size_t data_type_sizeof(DataType data_type) { case DataType::UNDEFINED_DATA_TYPE: return 0; case DataType::FP32: return 4; + case DataType::FP32_FP32: return 8; case DataType::FP16: return 2; + case DataType::FP16_FP16: return 4; case DataType::BF16: return 2; + case DataType::BF16_BF16: return 4; case DataType::FP8: return 1; + case DataType::BF8: return 1; + case DataType::FP64: return 8; case DataType::INT32: return 4; case DataType::I8: return 1; + case DataType::I8_I8: return 2; case DataType::U8: return 1; } + return 0; // Default case to ensure all control paths return a value } } // namespace ck_tile::builder::test diff --git a/include/ck/host_utility/device_prop.hpp b/include/ck/host_utility/device_prop.hpp index 8739f65740..43e9350f8f 100644 --- a/include/ck/host_utility/device_prop.hpp +++ b/include/ck/host_utility/device_prop.hpp @@ -72,7 +72,12 @@ inline bool is_xdl_supported() is_gfx12_supported() || is_gfx11_supported(); } -template +template inline bool is_xdl_wmma_supported() { if(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" || @@ -82,7 +87,7 @@ inline bool is_xdl_wmma_supported() } else if(is_gfx12_supported() || is_gfx11_supported()) { - if constexpr((MPerXDL != 16) || (NPerXDL != 16)) + if constexpr((MPerXDL32 != 16) || (NPerXDL32 != 16)) { return false; } diff --git a/include/ck/tensor_operation/gpu/device/device_base.hpp b/include/ck/tensor_operation/gpu/device/device_base.hpp index 3e37aac86e..9179a279c5 100644 --- a/include/ck/tensor_operation/gpu/device/device_base.hpp +++ b/include/ck/tensor_operation/gpu/device/device_base.hpp @@ -17,6 +17,7 @@ #endif #endif #include "ck/utility/get_id.hpp" +#include "ck/utility/sequence.hpp" namespace ck { namespace tensor_operation { @@ -96,6 +97,57 @@ static constexpr auto GetNXdlPerWave2() IsWave64>(); \ } +template +static constexpr auto GetWarpTileConfig() +{ + constexpr auto MXdlPerWave64 = MXdlPerWave_; + constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16; + constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16; + + constexpr auto NXdlPerWave = + IsWave64 + ? GetNXdlPerWave2() + : GetNXdlPerWave2(); + + if constexpr(IsWave64 == false && NXdlPerWave != 0) + { + constexpr auto CShuffleNXdlPerWavePerShuffle32 = + NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16 + ? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16 + : CShuffleNXdlPerWavePerShuffle_; + static_assert(CShuffleNXdlPerWavePerShuffle32 > 0); + return Sequence<16, + 16, + MXdlPerWave32, + NXdlPerWave, + CShuffleMXdlPerWavePerShuffle32, + CShuffleNXdlPerWavePerShuffle32>{}; + } + else + { + return Sequence{}; + } +} + #define INVOKER_RUN_IMPL \ float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \ { \ diff --git a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp index 83dbebb8d6..fff435f1c2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp @@ -166,11 +166,27 @@ struct DeviceContractionMultipleD_Xdl_CShuffle { using DeviceOp = DeviceContractionMultipleD_Xdl_CShuffle; - GET_NXDL_PER_WAVE_IMPL - static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); - static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - - static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto WarpTileConfig64 = GetWarpTileConfig(); + static constexpr auto WarpTileConfig32 = GetWarpTileConfig(); + static constexpr auto NXdlPerWave64 = WarpTileConfig64.At(3); + static constexpr auto NXdlPerWave32 = WarpTileConfig32.At(3); + static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; @@ -321,7 +337,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N({}, {})); // GridwiseGemm - template + template using GridwiseGemmBase = GridwiseGemmMultipleD_xdl_cshuffle< ADataType, // TODO: distinguish A/B datatype BDataType, @@ -340,10 +356,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle KPerBlock, AK1, BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave_, + WarpTileConfig::At(0), + WarpTileConfig::At(1), + WarpTileConfig::At(2), + WarpTileConfig::At(3), ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, @@ -360,13 +376,13 @@ struct DeviceContractionMultipleD_Xdl_CShuffle BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, + WarpTileConfig::At(4), + WarpTileConfig::At(5), CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; // desc for blockwise copy using AGridDesc_AK0_M_AK1 = @@ -588,7 +604,12 @@ struct DeviceContractionMultipleD_Xdl_CShuffle static bool IsSupportedArgument(const Argument& arg) { - if(!ck::is_xdl_wmma_supported()) + if(!ck::is_xdl_wmma_supported()) { return false; } @@ -783,6 +804,8 @@ struct DeviceContractionMultipleD_Xdl_CShuffle << MPerBlock << ", " << NPerBlock << ", " << KPerBlock << ", " + << MPerXDL << ", " + << NPerXDL << ", " << AK1 << ", " << BK1 << ", " << ABlockTransferSrcVectorDim << ", " diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 0e7d1def75..08d555d27c 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -366,6 +366,26 @@ struct amdgcn_compiler_target_state #else static constexpr bool CK_TILE_ARCH_GFX1010 = false; #endif +#if defined(__gfx1011__) + static constexpr bool CK_TILE_ARCH_GFX1011 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1011 = false; +#endif +#if defined(__gfx1012__) + static constexpr bool CK_TILE_ARCH_GFX1012 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1012 = false; +#endif +#if defined(__gfx1013__) + static constexpr bool CK_TILE_ARCH_GFX1013 = true; +#else + static constexpr bool CK_TILE_ARCH_GFX1013 = false; +#endif +#if defined(__gfx10_1_generic__) + static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = true; +#else + static constexpr bool CK_TILE_ARCH_GFX10_1_GENERIC = false; +#endif // __gfx10_1_generic__ #if defined(__gfx1030__) static constexpr bool CK_TILE_ARCH_GFX1030 = true; @@ -504,6 +524,10 @@ CK_TILE_HOST_DEVICE static constexpr uint32_t count_values_of(T search, Ts... se amdgcn_compiler_target_state::CK_TILE_ARCH_GFX942, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX950, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1010, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1011, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1012, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1013, \ + amdgcn_compiler_target_state::CK_TILE_ARCH_GFX10_1_GENERIC, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1030, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1031, \ amdgcn_compiler_target_state::CK_TILE_ARCH_GFX1032, \ diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index a1edce804f..5c99ae8a1c 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -68,7 +68,7 @@ auto shuffle_bq(const ck_tile::HostTensor* t, int block_bq_k) } template -auto shuffle_b(const ck_tile::HostTensor& t) +auto shuffle_b(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) { assert(t.get_lengths().size() == 2); int n_ = t.get_lengths()[1]; @@ -78,10 +78,10 @@ auto shuffle_b(const ck_tile::HostTensor& t) { constexpr int divisor = 2; constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, + int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, + gemmConfig.N_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); @@ -98,18 +98,24 @@ auto shuffle_b(const ck_tile::HostTensor& t) else { assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4; } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Warp_Tile, - GemmConfig::N_Warp_Tile, - k_ / GemmConfig::K_Warp_Tile, + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Warp_Tile, + gemmConfig.N_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, divisor, - GemmConfig::K_Warp_Tile / divisor}); + gemmConfig.K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } } +template +auto shuffle_b(const ck_tile::HostTensor& t) +{ + return shuffle_b(t, GemmConfig{}); +} + template auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) { @@ -129,22 +135,22 @@ auto bq_permuteN(const ck_tile::HostTensor& t, index_t group_n) } template -auto shuffle_b_permuteN(const ck_tile::HostTensor& t) +auto shuffle_b_permuteN(const ck_tile::HostTensor& t, const GemmConfig& gemmConfig) { assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp; + int n_ = t.get_lengths()[1]; + int k_ = t.get_lengths()[0]; + int NRepeat = gemmConfig.N_Tile / gemmConfig.N_Warp_Tile / gemmConfig.N_Warp; if(ck_tile::is_gfx12_supported()) { constexpr int divisor = 2; constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, + int kABK0PerLane = gemmConfig.K_Warp_Tile / divisor / kABK1PerLane; + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, + gemmConfig.N_Warp, + gemmConfig.N_Warp_Tile, NRepeat, - k_ / GemmConfig::K_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, kABK0PerLane, divisor, kABK1PerLane}); @@ -161,17 +167,23 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor& t) else { assert(is_wave32() == false); - divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4; + divisor = gemmConfig.N_Warp_Tile == 32 ? 2 : 4; } - ck_tile::HostTensor t_view({n_ / GemmConfig::N_Tile, - GemmConfig::N_Warp, - GemmConfig::N_Warp_Tile, + ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, + gemmConfig.N_Warp, + gemmConfig.N_Warp_Tile, NRepeat, - k_ / GemmConfig::K_Warp_Tile, + k_ / gemmConfig.K_Warp_Tile, divisor, - GemmConfig::K_Warp_Tile / divisor}); + gemmConfig.K_Warp_Tile / divisor}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); } } + +template +auto shuffle_b_permuteN(const ck_tile::HostTensor& t) +{ + return shuffle_b_permuteN(t, GemmConfig{}); +} } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp index 8029f6a2c7..aa8469be4f 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp @@ -43,4 +43,26 @@ struct TileGemmShape } }; +template +constexpr index_t get_k_warp_tile() +{ +#if CK_TILE_USE_WMMA + return 16; +#else +#if defined(CK_GFX950_SUPPORT) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32; + else + return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64; +#endif +#endif +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 628e9194ae..cb452043d1 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -61,6 +61,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; using BQDataType = remove_cvref_t; + using BLayout = remove_cvref_t; using BQLayout = remove_cvref_t; using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; @@ -156,9 +157,11 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using CDataType = remove_cvref_t; // BDataType gets converted from PkInt4 during loading - using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; - + using OverrideBDataType = std::conditional_t< + std::is_same_v && + std::is_same_v, + ADataType, + BDataType>; using Base = BlockGemmBQuantBase; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 2c191cc2b4..f6ebbd9228 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -33,9 +33,17 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using QuantGroupSize = remove_cvref_t; + using ALayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using CLayout = remove_cvref_t; + // BDataType gets converted from PkInt4 during loading using OverrideBDataType = - std::conditional_t, ADataType, BDataType>; + std::conditional_t && + std::is_same_v, + ADataType, + BDataType>; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; @@ -50,11 +58,6 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3>::PackedSize; - using ALayout = remove_cvref_t; - using BQLayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; static constexpr index_t BlockSize = Problem::kBlockSize; @@ -184,6 +187,23 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_dram_window); } + template + CK_TILE_DEVICE void + BGlobalPrefetch(BBlockTile_& b_block_tile, + BDramWindow& b_copy_dram_window, + const BDramTileWindowStep& b_dram_tile_window_step) const + { + if constexpr(!std::is_same_v) + { + LoadAndConvertBTile(b_block_tile, b_copy_dram_window); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + } + else + { + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + } + } + template (ABlockTileDistr{})); using BBlockTile = - decltype(make_static_distributed_tensor(BBlockTileDistr{})); + decltype(make_static_distributed_tensor(BBlockTileDistr{})); using BQBlockTile = decltype(make_static_distributed_tensor(BQBlockTileDistr{})); @@ -289,8 +309,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( + auto b_shuffle_tmp = make_static_distributed_tensor( Policy::template MakeShuffledBRegTileDistribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); @@ -322,8 +341,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using AQuantGrouped = std::integral_constant; -using BQuantGrouped = std::integral_constant; -using RowColQuant = std::integral_constant; -using TensorQuant = std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; -using GroupSize64 = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for AQuant tests -// Tuple format: -// clang-format off -using AQuantTypes = ::testing::Types< - // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // RCR layout - with the Prefill BlockTile Config. - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = false && TransposeC = true (with RowMajor AQ) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor) - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor) - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for AQuant -TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes); - -// AQuant tests -TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp new file mode 100644 index 0000000000..0e04f9fc9e --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_ccr.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - CCR layout +// Tuple format: +// clang-format off +using AQuantBaseCCRTypes = ::testing::Types< + // CCR layout (ColumnMajor A, ColumnMajor B, RowMajor C with ColumnMajor AQ) - NEW layout support + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Base CCR +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseCCRTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp new file mode 100644 index 0000000000..da32c06304 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rcr.cpp @@ -0,0 +1,42 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - RCR layout base configuration +// Tuple format: +// clang-format off +using AQuantBaseRCRTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = false (RCR layout with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Base RCR +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseRCRTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp new file mode 100644 index 0000000000..6e90c44764 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_base_rrr_crr.cpp @@ -0,0 +1,46 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - RRR and CRR layouts +// Tuple format: +// clang-format off +using AQuantBaseRRRCRRTypes = ::testing::Types< + // RRR layout (RowMajor A, RowMajor B, RowMajor C with RowMajor AQ) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // CRR layout (ColumnMajor A, RowMajor B, RowMajor C with RowMajor AQ) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Base RRR/CRR +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantBaseRRRCRRTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp new file mode 100644 index 0000000000..133c11860a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_prefill.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - Prefill Configuration +// Tuple format: +// clang-format off +using AQuantPrefillTypes = ::testing::Types< + // RCR layout - with the Prefill BlockTile Config. + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Prefill +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantPrefillTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp new file mode 100644 index 0000000000..35d15f9354 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_preshuffle.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - PreshuffleQuant Configurations +// Tuple format: +// clang-format off +using AQuantPreshuffleTypes = ::testing::Types< + // PreshuffleQuant = true && TransposeC = false (with RowMajor AQ - PreshuffleQuant only supports RowMajor) + std::tuple, + std::tuple, + std::tuple, + std::tuple, + + // PreshuffleQuant = true && TransposeC = true (with RowMajor AQ - PreshuffleQuant only supports RowMajor) + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant Preshuffle +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantPreshuffleTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp new file mode 100644 index 0000000000..a2a4c2c38b --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_aquant_transpose_c.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for AQuant tests - TransposeC Configuration +// Tuple format: +// clang-format off +using AQuantTransposeCTypes = ::testing::Types< + // PreshuffleQuant = false && TransposeC = true (with RowMajor AQ) + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for AQuant TransposeC +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTransposeCTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp deleted file mode 100644 index 4b1ad068a7..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant.cpp +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using BF16 = ck_tile::bf16_t; -using UInt8 = ck_tile::pk_fp4_raw_t; -using BQuantGrouped = std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; -using GroupSize64 = ck_tile::QuantGroupShape>; -using GroupSize32 = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; -using GroupSize2D128N = ck_tile::QuantGroupShape>; - -// Type combinations for BQuant tests (without PreshuffleB) -// Tuple format: -// clang-format off -using BQuantTypes = ::testing::Types< - // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - - // 2d cases with grouping also on the n axis - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // some cases with transpose layouts - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, - std::tuple, - std::tuple, - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, - std::tuple, - std::tuple, - - // pkint4 + transpose cases - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, - std::tuple, - std::tuple, - std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for BQuant (without PreshuffleB) -TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes); - -// BQuant tests -TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp new file mode 100644 index 0000000000..d491d89ef4 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 128 +// Tuple format: +// clang-format off +using BQuant1D128Types = ::testing::Types< + // 1d cases with grouping only on k axis + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 128 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D128Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp new file mode 100644 index 0000000000..1019caf1bc --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 1D GroupSize 64 +// Tuple format: +// clang-format off +using BQuant1D64Types = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 1D 64 +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant1D64Types); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp new file mode 100644 index 0000000000..a8b6dcd14b --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_large_n.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 2D Large N (128N) +// Tuple format: +// clang-format off +using BQuant2DLargeNTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 2D Large N +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DLargeNTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp new file mode 100644 index 0000000000..67d52ef874 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_medium_n.cpp @@ -0,0 +1,48 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 2D Medium N (32N and 64N) +// Tuple format: +// clang-format off +using BQuant2DMediumNTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 2D Medium N +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DMediumNTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp new file mode 100644 index 0000000000..865713992d --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_2d_small_n.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - 2D Small N (8N and 16N) +// Tuple format: +// clang-format off +using BQuant2DSmallNTypes = ::testing::Types< + // 2d cases with grouping also on the n axis + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant 2D Small N +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuant2DSmallNTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp deleted file mode 100644 index ae01bddf96..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle.cpp +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include "ck_tile/host.hpp" -#include "ck_tile/ops/gemm.hpp" - -#include -#include - -#include "test_gemm_quant_fixtures.hpp" - -// Type aliases for readability -using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; -using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; -using FP8 = ck_tile::fp8_t; -using BF8 = ck_tile::bf8_t; -using Half = ck_tile::half_t; -using PkInt4 = ck_tile::pk_int4_t; -using BQuantGrouped = std::integral_constant; -using GroupSize = ck_tile::QuantGroupShape>; - -// 2d block sizes for BQuant -using GroupSize2D8N = ck_tile::QuantGroupShape>; -using GroupSize2D16N = ck_tile::QuantGroupShape>; -using GroupSize2D32N = ck_tile::QuantGroupShape>; -using GroupSize2D64N = ck_tile::QuantGroupShape>; - -// Type combinations for BQuant tests with PreshuffleB -// Tuple format: -// clang-format off -using BPreshuffleBQuantTypes = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - // //2d cases with preshuffle B - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - - - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple, - std::tuple ->; -// clang-format on - -// Test suite for BQuant with PreshuffleB -TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleBQuantTypes); - -// BQuant PreshuffleB tests -TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) -{ - this->run_test_with_validation(1024, 1024, 1024); -} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp new file mode 100644 index 0000000000..cf599ebbfd --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_1d.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Decode Config 1D +// Tuple format: +// clang-format off +using BPreshuffleDecode1DTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Decode 1D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode1DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp new file mode 100644 index 0000000000..65ea165b10 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_decode_2d.cpp @@ -0,0 +1,51 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Decode 2D +// Tuple format: +// clang-format off +using BPreshuffleDecode2DTypes = ::testing::Types< + // 2d cases with preshuffle B + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Decode 2D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleDecode2DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp new file mode 100644 index 0000000000..3f6dd225d7 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_1d.cpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Prefill Config 1D +// Tuple format: +// clang-format off +using BPreshufflePrefill1DTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Prefill 1D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill1DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp new file mode 100644 index 0000000000..368204987a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_prefill_2d.cpp @@ -0,0 +1,58 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; + +// 2d block sizes for BQuant +using GroupSize2D8N = ck_tile::QuantGroupShape>; +using GroupSize2D16N = ck_tile::QuantGroupShape>; +using GroupSize2D32N = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - Prefill 2D +// Tuple format: +// clang-format off +using BPreshufflePrefill2DTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle Prefill 2D +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshufflePrefill2DTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp new file mode 100644 index 0000000000..8a05f5812a --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_preshuffle_tiled_permute.cpp @@ -0,0 +1,40 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant Preshuffle tests - TiledPermuteN Config +// Tuple format: +// clang-format off +using BPreshuffleTiledPermuteTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Preshuffle TiledPermuteN +TYPED_TEST_SUITE(TestCkTileGemmPreshuffleBBQuant, BPreshuffleTiledPermuteTypes); + +// BQuant PreshuffleB tests +TYPED_TEST(TestCkTileGemmPreshuffleBBQuant, BQuantPreshuffleTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp new file mode 100644 index 0000000000..230dd8f0fc --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_transpose.cpp @@ -0,0 +1,53 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize64 = ck_tile::QuantGroupShape>; +using GroupSize2D64N = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant tests - Transpose Layouts +// Tuple format: +// clang-format off +using BQuantTransposeTypes = ::testing::Types< + // some cases with transpose layouts + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple, + + // pkint4 + transpose cases + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize64>, + std::tuple, + std::tuple, + std::tuple< RowMajor, RowMajor, RowMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize2D64N>, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant Transpose +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTransposeTypes); + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index 5628b6feae..a7189e7865 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -11,26 +11,6 @@ #include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" -template -constexpr ck_tile::index_t get_k_warp_tile_flatmm() -{ -#if CK_TILE_USE_WMMA - return 16; -#else -#if defined(CK_GFX950_SUPPORT) - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 64; - else - return sizeof(PrecType) == 2 ? 32 : 128; -#else - if constexpr(M_Warp_Tile == 32) - return sizeof(PrecType) == 2 ? 16 : 32; - else - return sizeof(PrecType) == 2 ? 32 : 64; -#endif -#endif -} - template class TestCkTileGroupedGemmPreshuffle : public ::testing::Test { @@ -67,7 +47,8 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test static const ck_tile::index_t M_Warp_Tile = 16; static const ck_tile::index_t N_Warp_Tile = 16; - static const ck_tile::index_t K_Warp_Tile = get_k_warp_tile_flatmm(); + static const ck_tile::index_t K_Warp_Tile = + ck_tile::get_k_warp_tile(); static constexpr bool DoubleSmemBuffer = true; // preshuffle v2 uses ping-pong smem static constexpr bool TransposeC = false; // transpose c is not supported @@ -101,46 +82,6 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test return gemm_descs.size() * sizeof(ck_tile::GemmTransKernelArg<>); } - template - auto shuffle_b(const ck_tile::HostTensor& t) - { - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - - if(ck_tile::is_gfx12_supported()) - { - constexpr int divisor = 2; - constexpr int kABK1PerLane = 8; - constexpr int kABK0PerLane = K_Warp_Tile / divisor / kABK1PerLane; - ck_tile::HostTensor t_view({n_ / N_Warp_Tile, - N_Warp_Tile, - k_ / K_Warp_Tile, - kABK0PerLane, - divisor, - kABK1PerLane}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5}); - } - else - { - int divisor = 1; - if(ck_tile::is_gfx11_supported()) - { - divisor = 1; - } - else - { - assert(is_wave32() == false); - divisor = N_Warp_Tile == 32 ? 2 : 4; - } - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - } - template void invoke_grouped_gemm(const std::vector& gemm_descs, const ck_tile::stream_config& s, @@ -340,6 +281,14 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test } } + struct BShuffleGemmConfig + { + static constexpr ck_tile::index_t N_Warp_Tile = + TestCkTileGroupedGemmPreshuffle::N_Warp_Tile; + static constexpr ck_tile::index_t K_Warp_Tile = + TestCkTileGroupedGemmPreshuffle::K_Warp_Tile; + }; + public: void Run(const std::vector& Ms, const std::vector& Ns, @@ -424,7 +373,7 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensors[i]); // Host-side preshuffle of B - auto b_shuffle_host = shuffle_b(b_k_n_tensors[i]); + auto b_shuffle_host = ck_tile::shuffle_b(b_k_n_tensors[i]); a_m_k_dev_buf.push_back(std::make_unique( a_m_k_tensors[i].get_element_space_size_in_bytes())); diff --git a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt index 892e123d3d..55f09726cc 100644 --- a/test/ck_tile/grouped_gemm_quant/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_quant/CMakeLists.txt @@ -6,18 +6,18 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() -if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") - # Split into three separate test executables for faster parallel compilation - add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +# if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") +# # Split into three separate test executables for faster parallel compilation +# add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp) +# target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +# add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp) +# target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) - add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) - target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) +# add_gtest_executable(test_ck_tile_grouped_gemm_quant_aquant test_grouped_gemm_quant_aquant.cpp) +# target_compile_options(test_ck_tile_grouped_gemm_quant_aquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) # add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp) # target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -endif() +# endif() diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp index bb0b8090fa..8c0c5f78d4 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_common.hpp @@ -81,42 +81,3 @@ inline KernelTraits extract_traits_from_name(const std::string& kernel_name) return traits; } - -template -auto shuffle_b(const ck_tile::HostTensor& t, - ck_tile::index_t N_Warp_Tile, - ck_tile::index_t K_Warp_Tile) -{ - assert(t.get_lengths().size() == 2); - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - int divisor = N_Warp_Tile == 32 ? 2 : 4; - ck_tile::HostTensor t_view( - {n_ / N_Warp_Tile, N_Warp_Tile, k_ / K_Warp_Tile, divisor, K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); -} - -template -auto shuffle_b_permuteN(const ck_tile::HostTensor& t, - ck_tile::index_t N_Warp_Tile, - ck_tile::index_t K_Warp_Tile, - ck_tile::index_t N_Tile, - ck_tile::index_t N_Warp) -{ - assert(t.get_lengths().size() == 2); - - int n_ = t.get_lengths()[1]; - int k_ = t.get_lengths()[0]; - int divisor = N_Warp_Tile == 32 ? 2 : 4; - int NRepeat = N_Tile / N_Warp_Tile / N_Warp; - ck_tile::HostTensor t_view({n_ / N_Tile, - N_Warp, - N_Warp_Tile, - NRepeat, - k_ / K_Warp_Tile, - divisor, - K_Warp_Tile / divisor}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); -} diff --git a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp index 739bd7e677..cad53b472f 100644 --- a/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp +++ b/tile_engine/ops/gemm_preshuffle/gemm_preshuffle_profiler.hpp @@ -111,21 +111,30 @@ class GemmProfiler c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); + struct GemmConfig + { + ck_tile::index_t N_Warp_Tile; + ck_tile::index_t K_Warp_Tile; + ck_tile::index_t N_Tile; + ck_tile::index_t N_Warp; + }; + for(const auto& callable : callables) { - ck_tile::index_t N_Warp_Tile = std::get<1>(config.warp_tile_dims); - ck_tile::index_t K_Warp_Tile = std::get<2>(config.warp_tile_dims); - ck_tile::index_t N_Tile = std::get<1>(config.tile_dims); - ck_tile::index_t N_Warp = std::get<1>(config.warp_dims); + GemmConfig gemmConfig = {}; + gemmConfig.N_Warp_Tile = std::get<1>(config.warp_tile_dims); + gemmConfig.K_Warp_Tile = std::get<2>(config.warp_tile_dims); + gemmConfig.N_Tile = std::get<1>(config.tile_dims); + gemmConfig.N_Warp = std::get<1>(config.warp_dims); ck_tile::HostTensor b_shuffle_host = [&]() { if(config.permuteN) { - return shuffle_b_permuteN(b_k_n, N_Warp_Tile, K_Warp_Tile, N_Tile, N_Warp); + return ck_tile::shuffle_b_permuteN(b_k_n, gemmConfig); } else { - return shuffle_b(b_k_n, N_Warp_Tile, K_Warp_Tile); + return ck_tile::shuffle_b(b_k_n, gemmConfig); } }();