Add the padding algorithm

This commit is contained in:
ThomasNing
2026-01-28 23:15:21 -06:00
parent 5b6cbd329c
commit 24baa5245f

View File

@@ -298,22 +298,33 @@ struct CShuffleEpilogue
{
constexpr auto DataTypeSize = sizeof(ODataType);
constexpr index_t VectorLen = GetVectorSizeC();
constexpr index_t Banks = get_n_lds_banks();
constexpr index_t banks = get_n_lds_banks();
// calculate how many elements to pad to avoid bank conflict
#if defined(__gfx950__)
constexpr auto PaddingAmount = VectorLen;
#else
constexpr auto PaddingAmount = 0;
#endif
constexpr index_t BytesPerBank = 4;
// N is contiguous dimension
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
{
constexpr index_t MLdsLayerRequired =
Banks * BytesPerBank / NPerIterationShuffle / DataTypeSize;
banks * BytesPerBank / NPerIterationShuffle / DataTypeSize;
constexpr auto MLdsLayer = max(1, MLdsLayerRequired);
constexpr index_t BaseStrideElems = NPerIterationShuffle * MLdsLayer;
static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
"LDS row stride must be 4B-aligned for bank-word padding logic");
// calculate how many elements to pad to avoid bank conflict
#if defined(__gfx950__)
constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
constexpr auto ToWords = [](index_t elems) constexpr {
return (elems * DataTypeSize) / BytesPerBank;
};
constexpr index_t BaseWords = ToWords(BaseStrideElems);
constexpr index_t PadWords = ((BaseWords % 2) == 0) ? 1 : 0;
constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
constexpr auto PaddingAmount = 0;
#endif
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<MPerIterationShuffle / MLdsLayer>{},
number<NPerIterationShuffle / VectorLen * MLdsLayer>{},
@@ -351,6 +362,23 @@ struct CShuffleEpilogue
get_n_lds_banks() * BytesPerBank / MPerIterationShuffle / DataTypeSize;
constexpr auto NLdsLayer = max(1, NLdsLayerRequired);
constexpr index_t BaseStrideElems = MPerIterationShuffle * NLdsLayer;
static_assert((BaseStrideElems * DataTypeSize) % BytesPerBank == 0,
"LDS row stride must be 4B-aligned for bank-word padding logic");
#if defined(__gfx950__)
constexpr index_t ElemsPer4B = BytesPerBank / ck_tile::gcd(BytesPerBank, DataTypeSize);
constexpr auto ToWords = [](index_t elems) constexpr {
return (elems * DataTypeSize) / BytesPerBank;
};
constexpr index_t BaseWords = ToWords(BaseStrideElems);
constexpr index_t PadWords = ((BaseWords % 2) == 0) ? 1 : 0;
constexpr auto PaddingAmount = PadWords * ElemsPer4B;
#else
constexpr auto PaddingAmount = 0;
#endif
constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NPerIterationShuffle / NLdsLayer>{},
number<MPerIterationShuffle / VectorLen * NLdsLayer>{},
@@ -768,7 +796,7 @@ struct CShuffleEpilogue
MPerIterationShuffle,
NPerIterationShuffle,
GetVectorSizeC(),
tile_distribution_pattern::warp_raked,
tile_distribution_pattern::thread_raked,
Problem::kNumWaveGroups>;
constexpr auto dram_tile_distribution =
TileEncodingPattern::make_2d_static_tile_distribution();