diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index 1f6c389090..c96daf3e99 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -104,7 +104,7 @@ enum struct tile_distribution_pattern block_raked, }; -struct TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern { }; @@ -126,7 +126,7 @@ template -struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_2d : public tile_distribution_encoding_pattern { }; @@ -136,12 +136,13 @@ template -struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_2d + : public tile_distribution_encoding_pattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! @@ -165,7 +166,7 @@ struct TileDistributionEncodingPattern2D -struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_2d + : public tile_distribution_encoding_pattern { static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -244,7 +246,7 @@ struct TileDistributionEncodingPattern2D, @@ -255,7 +257,7 @@ struct TileDistributionEncodingPattern2D>{}); // -> } - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution() { return make_static_tile_distribution( tile_distribution_encoding, @@ -273,12 +275,13 @@ template -struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_2d + : public tile_distribution_encoding_pattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! @@ -295,7 +298,7 @@ struct TileDistributionEncodingPattern2D, @@ -306,7 +309,7 @@ struct TileDistributionEncodingPattern2D>{}); // -> } - CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_shuffled_2d_static_tile_distribution() { return make_static_tile_distribution( tile_distribution_encoding, @@ -336,21 +339,21 @@ template -CK_TILE_HOST_DEVICE void print(const TileDistributionEncodingPattern2D&) +CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d&) { - using PatternType = TileDistributionEncodingPattern2D; + using PatternType = tile_distribution_encoding_pattern_2d; - printf("TileDistributionEncodingPattern2D: ", BlockSize, YPerTile, diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp index 3b8d5a142e..9e2a67f940 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_common_policy.hpp @@ -21,12 +21,12 @@ struct BatchedTransposeCommonPolicy constexpr index_t kVectorSize = Problem::VectorSizeInput; static_assert((kLeadDimPerBlock * kVectorSize) % kBlockSize == 0, ""); - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } }; diff --git a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp index e6bbc709ea..137584c3e8 100644 --- a/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp +++ b/include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp @@ -18,12 +18,12 @@ struct BatchedTransposePolicy : public BatchedTransposeCommonPolicy constexpr index_t NPerBlock = Problem::kNPerBlock; constexpr index_t VecLoadSize = Problem::VectorSizeOutput; - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 1d0a4c42f4..7510df091c 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -291,13 +291,14 @@ struct CShuffleEpilogue "Currently, the CShuffle Epilogue only supports the Row Major Output layout"); using TileEncodingPattern = - TileDistributionEncodingPattern2D; - constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution(); + tile_distribution_encoding_pattern_2d; + constexpr auto dram_tile_distribution = + TileEncodingPattern::make_2d_static_tile_distribution(); auto d_dram_windows = generate_tuple( [&](auto idx) { diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 40ee952b1b..8d47ab878e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -168,11 +168,11 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t VecLoadSize = GetVectorSizeB(); - using TileEncodingPattern = TileDistributionEncodingPattern2D; + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; constexpr auto BK0 = number{}; constexpr auto BK1 = number{}; @@ -494,24 +494,24 @@ struct UniversalGemmBasePolicy // Tile: MPerBlock X KPerBlock if constexpr(std::is_same_v) { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: KPerBlock X MPerBlock else { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } } @@ -530,24 +530,24 @@ struct UniversalGemmBasePolicy // Tile: KPerBlock X NPerBlock if constexpr(std::is_same_v) { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } // Tile: NPerBlock X KPerBlock else { - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::Make2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_2d_static_tile_distribution(); } } @@ -562,13 +562,13 @@ struct UniversalGemmBasePolicy constexpr index_t VecLoadSize = GetVectorSizeA(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } template @@ -582,13 +582,13 @@ struct UniversalGemmBasePolicy constexpr index_t VecLoadSize = GetVectorSizeB(); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; - using TileEncodingPattern = TileDistributionEncodingPattern2D; - return TileEncodingPattern::MakeShuffled2DStaticTileDistribution(); + using TileEncodingPattern = tile_distribution_encoding_pattern_2d; + return TileEncodingPattern::make_shuffled_2d_static_tile_distribution(); } template diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp index 52c99f8e99..926f63b5a9 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp @@ -55,44 +55,43 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC static_assert(std::is_same_v); if constexpr(PreshuffleQuant) { - using TileEncodingPattern = - TileDistributionEncodingPatternAQ; + using TileEncodingPattern = tile_distribution_encoding_pattern_aq< + BlockGemmShape, + WarpGemm, + BlockSize, + MPerBlock / WarpGemm::kM, + ck_tile::integer_least_multiple(WarpGemm::kM * KPerBlockAQ, get_warp_size()), + KPerBlockAQ, + VecLoadSize, + PreshuffleQuant>; - return TileEncodingPattern::Make2DStaticTileDistribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); } else { if constexpr(Problem::TransposeC) { using TileEncodingPatternTransposeC = - TileDistributionEncodingPatternAQTransposedC; - return TileEncodingPatternTransposeC::Make2DStaticTileDistribution(); + tile_distribution_encoding_pattern_aq_transposed_c; + return TileEncodingPatternTransposeC::make_2d_static_tile_distribution(); } else { - using TileEncodingPattern = TileDistributionEncodingPatternAQ; + using TileEncodingPattern = tile_distribution_encoding_pattern_aq; - return TileEncodingPattern::Make2DStaticTileDistribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); } } } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 037cef0553..5ce4268dca 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -330,7 +330,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template make_shuffled_2d_static_tile_distribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -342,7 +342,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template make_shuffled_2d_static_tile_distribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index ff986d86fb..eea8038edf 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -52,14 +52,14 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC Problem::TransposeC>; static_assert(std::is_same_v); - using TileEncodingPattern = TileDistributionEncodingPatternBQ; + using TileEncodingPattern = tile_distribution_encoding_pattern_bq; - return TileEncodingPattern::Make2DStaticTileDistribution(); + return TileEncodingPattern::make_2d_static_tile_distribution(); } template diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 7ce6598b80..8f191f0f94 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -326,7 +326,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV if constexpr(is_a_col_major) { auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template make_shuffled_2d_static_tile_distribution()); transpose_tile2d(a_shuffle_tmp, a_block_tile); Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); } @@ -338,7 +338,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseBQuantGemmPipelineAgBgCrCompV if constexpr(is_b_row_major) { auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffled2DStaticTileDistribution()); + Policy::template make_shuffled_2d_static_tile_distribution()); transpose_tile2d(b_shuffle_tmp, b_block_tile); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp index 56a906a6bc..54b64c34be 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp @@ -53,7 +53,7 @@ template -struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_aq : public tile_distribution_encoding_pattern { static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); static constexpr index_t warp_size = get_warp_size(); @@ -70,7 +70,7 @@ struct TileDistributionEncodingPatternAQ : public TileDistributionEncodingPatter // KWarps > 1 isn't supported static_assert(KWarps == 1); - CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { if constexpr(PreshuffleQuant) { @@ -119,7 +119,8 @@ template -struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_aq_transposed_c + : public tile_distribution_encoding_pattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -152,7 +153,7 @@ struct TileDistributionEncodingPatternAQTransposedC : public TileDistributionEnc static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); - CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { return make_static_tile_distribution( tile_distribution_encoding, @@ -171,7 +172,7 @@ template -struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPattern +struct tile_distribution_encoding_pattern_bq : public tile_distribution_encoding_pattern { // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk! static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!"); @@ -204,7 +205,7 @@ struct TileDistributionEncodingPatternBQ : public TileDistributionEncodingPatter static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover the blocktile along Y."); - CK_TILE_HOST_DEVICE static constexpr auto Make2DStaticTileDistribution() + CK_TILE_HOST_DEVICE static constexpr auto make_2d_static_tile_distribution() { return make_static_tile_distribution( tile_distribution_encoding, diff --git a/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp b/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp index 3ff23e2e11..3b1b6ffb6d 100644 --- a/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp +++ b/test/ck_tile/utility/print/test_print_static_encoding_pattern.cpp @@ -32,13 +32,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintThreadRakedPattern) { // Test printing thread raked pattern using PatternType = - TileDistributionEncodingPattern2D<64, 8, 16, 4, tile_distribution_pattern::thread_raked>; + tile_distribution_encoding_pattern_2d<64, + 8, + 16, + 4, + tile_distribution_pattern::thread_raked>; PatternType pattern; std::string output = CapturePrintOutput(pattern); // Verify the output contains expected information - EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos); EXPECT_TRUE(output.find("BlockSize:64") != std::string::npos); EXPECT_TRUE(output.find("YPerTile:8") != std::string::npos); EXPECT_TRUE(output.find("XPerTile:16") != std::string::npos); @@ -52,13 +56,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintWarpRakedPattern) { // Test printing warp raked pattern using PatternType = - TileDistributionEncodingPattern2D<128, 16, 32, 8, tile_distribution_pattern::warp_raked>; + tile_distribution_encoding_pattern_2d<128, + 16, + 32, + 8, + tile_distribution_pattern::warp_raked>; PatternType pattern; std::string output = CapturePrintOutput(pattern); // Verify the output contains expected information - EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos); EXPECT_TRUE(output.find("BlockSize:128") != std::string::npos); EXPECT_TRUE(output.find("YPerTile:16") != std::string::npos); EXPECT_TRUE(output.find("XPerTile:32") != std::string::npos); @@ -72,13 +80,17 @@ TEST_F(PrintStaticEncodingPatternTest, PrintBlockRakedPattern) { // Test printing block raked pattern using PatternType = - TileDistributionEncodingPattern2D<256, 32, 64, 16, tile_distribution_pattern::block_raked>; + tile_distribution_encoding_pattern_2d<256, + 32, + 64, + 16, + tile_distribution_pattern::block_raked>; PatternType pattern; std::string output = CapturePrintOutput(pattern); // Verify the output contains expected information - EXPECT_TRUE(output.find("TileDistributionEncodingPattern2D") != std::string::npos); + EXPECT_TRUE(output.find("tile_distribution_encoding_pattern_2d") != std::string::npos); EXPECT_TRUE(output.find("BlockSize:256") != std::string::npos); EXPECT_TRUE(output.find("YPerTile:32") != std::string::npos); EXPECT_TRUE(output.find("XPerTile:64") != std::string::npos);