mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
Solve the Static Encoding Pattern compile error when the tile size is too small (#2079)
This commit is contained in:
@@ -59,6 +59,7 @@
|
||||
#include "ck_tile/core/tensor/transpose_tile.hpp"
|
||||
#include "ck_tile/core/tensor/update_tile.hpp"
|
||||
#include "ck_tile/core/utility/bit_cast.hpp"
|
||||
#include "ck_tile/core/utility/env.hpp"
|
||||
#include "ck_tile/core/utility/functional.hpp"
|
||||
#include "ck_tile/core/utility/functional_with_tuple.hpp"
|
||||
#include "ck_tile/core/utility/ignore.hpp"
|
||||
|
||||
@@ -73,10 +73,11 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t X1 = VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
|
||||
// # of rows in Y dim accessed by single wavefront in one iteration
|
||||
static constexpr index_t Y1 = warp_size / X0;
|
||||
@@ -124,10 +125,11 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
{
|
||||
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t X1 = VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
|
||||
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
|
||||
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
|
||||
@@ -173,10 +175,11 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
|
||||
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
|
||||
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t X1 = VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t warp_size = get_warp_size();
|
||||
static constexpr index_t num_warps = BlockSize / get_warp_size();
|
||||
static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
|
||||
static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
|
||||
static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
|
||||
static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
|
||||
static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
|
||||
static constexpr index_t Y1 = num_warps;
|
||||
|
||||
@@ -4,9 +4,9 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
Reference in New Issue
Block a user