From 1b61d3a0ed82d0fd853899bbb9299c534f2aa373 Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Sun, 13 Apr 2025 20:09:30 -0700 Subject: [PATCH] Solve the Static Encoding Pattern compile error when the tile size is too small (#2079) [ROCm/composable_kernel commit: 269f4f6af5aba8c8ac6fe215fcf6ea604dc6b101] --- include/ck_tile/core.hpp | 1 + .../algorithm/static_encoding_pattern.hpp | 27 ++++++++++--------- include/ck_tile/ops/epilogue.hpp | 2 +- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index d9aa8b3551..821b3a8e84 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -59,6 +59,7 @@ #include "ck_tile/core/tensor/transpose_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/env.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp" #include "ck_tile/core/utility/ignore.hpp" diff --git a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp index 78884f3f9f..b56bda3741 100644 --- a/include/ck_tile/core/algorithm/static_encoding_pattern.hpp +++ b/include/ck_tile/core/algorithm/static_encoding_pattern.hpp @@ -73,10 +73,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim // # of rows in Y dim accessed by single wavefront in one iteration static constexpr index_t Y1 = warp_size / X0; @@ -124,10 +125,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); @@ -173,10 +175,11 @@ struct TileDistributionEncodingPattern2D LargestVec ? LargestVec : VecSize; + static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!"); static constexpr index_t Y1 = num_warps; diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 12e53e13e6..6cc0fa8540 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -4,9 +4,9 @@ #pragma once #include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp" +#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" -#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp"