Gridwise gemm conv v3 force padded layout on gfx950 (#2961)

* Gridwise gemm conv v3 force padded layout on gfx950

* fix bug in other gridwise

* fix

* Update gridwise_gemm_wmma_cshuffle_v3_common.hpp
Author: Bartłomiej Kocot
Date: 2025-10-21 15:41:02 +02:00
Committed by: GitHub
Parent: 35754d2ec8
Commit: 3a28632b20
3 changed files with 34 additions and 8 deletions
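
For context on the title: LDS on AMD GPUs is divided into banks of 4-byte words (32 banks is the usual figure; treat it as an assumption here). When the stride between k0 slices of a block tile in LDS is a multiple of the bank span, accesses that step along k0 land in the same bank on every step and serialize. Padding the M extent by one element staggers successive slices across banks. A minimal standalone sketch of that arithmetic, with illustrative tile sizes that are not taken from this commit:

#include <cstdio>

// Illustrative tile parameters (not the commit's values).
constexpr int MPerBlock = 128;
constexpr int AK1       = 8;  // elements per K1 slice
constexpr int ElemSize  = 2;  // bytes, e.g. fp16
constexpr int NumBanks  = 32; // 4-byte LDS banks (assumed)

// Bank hit by element (k0, m) for a given k0 stride (in elements).
constexpr int bank(int k0, int m, int k0Stride)
{
    int byteAddr = (k0 * k0Stride + m * AK1) * ElemSize;
    return (byteAddr / 4) % NumBanks;
}

int main()
{
    for (int pad = 0; pad <= 1; ++pad)
    {
        int k0Stride = (MPerBlock + pad) * AK1;
        std::printf("pad=%d, banks along k0 at m=0:", pad);
        for (int k0 = 0; k0 < 8; ++k0)
            std::printf(" %d", bank(k0, 0, k0Stride));
        std::printf("\n");
    }
}

With pad = 0 every k0 step hits bank 0; with pad = 1 the steps spread over banks 0, 4, 8, ..., removing the serialization. That is the effect the commit forces on gfx950.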

@@ -45,7 +45,7 @@ template <typename ALayout,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
-index_t ABlockLdsExtraM,
+index_t ABlockLdsExtraMCustom,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
@@ -53,7 +53,7 @@ template <typename ALayout,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
-index_t BBlockLdsExtraN,
+index_t BBlockLdsExtraNCustom,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -272,12 +272,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
+#if defined(__gfx950__)
+// Force use padded layout on gfx950 to reduce bank conflicts
+constexpr index_t ABlockLdsExtraM = 1;
+#else
+constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom;
+#endif
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM)
{
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
-make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
+make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
// The xor tensor transformation requires more VGPRs than necessary and can
// cause register spills in some cases.
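
Note that the stride change in this hunk is a bug fix on top of the gfx950 override: the old tuple (AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1) applied the padding to KPerBlock and paired the strides with the wrong dimensions of the (AK0, MPerBlock, AK1) view. The corrected strides describe a row-major layout whose AK0 slices sit (MPerBlock + pad) * AK1 elements apart. A sketch of the offset arithmetic the new descriptor implies (a plain function, not CK's descriptor machinery; the numbers are illustrative):

// offset(k0, m, k1) under strides ((MPerBlock + pad) * AK1, AK1, 1)
constexpr long offset(long k0, long m, long k1,
                      long MPerBlock, long AK1, long pad)
{
    return k0 * (MPerBlock + pad) * AK1 + m * AK1 + k1;
}

// The first element of slice k0 + 1 lies pad * AK1 elements past the
// last element of slice k0, so slices never overlap.
static_assert(offset(1, 0, 0, 128, 8, 1)
                  == offset(0, 127, 7, 128, 8, 1) + 1 + 1 * 8,
              "adjacent k0 slices are separated by pad * AK1 elements");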
@@ -412,12 +418,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
+#if defined(__gfx950__)
+// Force use padded layout on gfx950 to reduce bank conflicts
+constexpr index_t BBlockLdsExtraN = 1;
+#else
+constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
+#endif
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN)
{
return make_naive_tensor_descriptor(
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
-make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
+make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
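
The B-side mirrors the A-side with NPerBlock in place of MPerBlock. The dispatch itself is purely compile-time: __gfx950__ is defined by the compiler only in the gfx950 device pass, so in a multi-architecture fat binary every target resolves its own constant with no runtime cost. The pattern in isolation, reduced to a sketch (the trait name is made up for illustration):

// Per-architecture override of a tuning knob at compile time.
template <int ExtraCustom>
struct EffectiveLdsPad
{
#if defined(__gfx950__)
    static constexpr int value = 1;           // force the padded layout
#else
    static constexpr int value = ExtraCustom; // honour the caller's knob
#endif
};

static constexpr int pad = EffectiveLdsPad</*ExtraCustom=*/0>::value;

This also explains the rename in the template parameter lists above: the user-facing parameter becomes ABlockLdsExtraMCustom / BBlockLdsExtraNCustom, and the forced local constexpr reuses the old name so the rest of the struct compiles unchanged.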

@@ -828,7 +828,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
// loop to hide it in v4. It may give some benefit from fewer VALU
// instructions in address computation.
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
-make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
+make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
// The xor tensor transformation requires more VGPRs than necessary and can
// cause register spills in some cases.
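
This "other gridwise" fix addresses a different flavour of the same bug: here the dimension order was right, but the AK0 stride dropped the pad, so a nonzero ABlockLdsExtraM never reached the actual layout and the requested padding was a silent no-op. A one-line check of the gap between adjacent AK0 slices (illustrative numbers, plain C++ rather than CK descriptors):

// Distance between the end of one AK0 slice and the start of the next.
constexpr long gap(long strideAK0, long MPerBlock, long AK1)
{
    return strideAK0 - MPerBlock * AK1;
}

static_assert(gap(128 * 8, 128, 8) == 0,       // old stride
              "old stride packs slices with no pad");
static_assert(gap((128 + 1) * 8, 128, 8) == 8, // new stride
              "new stride skews slices by pad * AK1 elements");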

@@ -131,7 +131,7 @@ template <typename ALayout,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
-index_t ABlockLdsExtraM,
+index_t ABlockLdsExtraMCustom,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
@@ -139,7 +139,7 @@ template <typename ALayout,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
-index_t BBlockLdsExtraN,
+index_t BBlockLdsExtraNCustom,
index_t CShuffleMXdlPerWavePerShuffle,
index_t CShuffleNXdlPerWavePerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
@@ -698,6 +698,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
+#if defined(__gfx950__)
+// Force use padded layout on gfx950 to reduce bank conflicts
+constexpr index_t ABlockLdsExtraM = 1;
+#else
+constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom;
+#endif
// A matrix in LDS memory, dst of blockwise copy
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
@@ -705,7 +712,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
// loop to hide it in v4. It may give some benefit from fewer VALU
// instructions in address computation.
return make_naive_tensor_descriptor(
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
-make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
+make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
}
// The xor tensor transformation requires more VGPRs than necessary and can
// cause register spills in some cases.
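
The selection logic in these hunks is worth spelling out: the linear padded descriptor is taken when padding is requested or when the v4 pipeline is in play, and the xor-swizzled descriptor (conflict-free without extra LDS, but with heavier address math) covers the rest. Since gfx950 now forces the pad to 1, every pipeline version takes the linear path there. The predicate, reduced to a checkable sketch:

enum class PipelineVer { v1, v2, v3, v4 };

// Mirrors: if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == v4)
constexpr bool uses_padded_linear(int pad, PipelineVer ver)
{
    return pad != 0 || ver == PipelineVer::v4;
}

static_assert(uses_padded_linear(0, PipelineVer::v4),  "v4 is always linear");
static_assert(uses_padded_linear(1, PipelineVer::v1),  "forced pad on gfx950");
static_assert(!uses_padded_linear(0, PipelineVer::v1), "otherwise xor swizzle");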
@@ -840,6 +847,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
+#if defined(__gfx950__)
+// Force use padded layout on gfx950 to reduce bank conflicts
+constexpr index_t BBlockLdsExtraN = 1;
+#else
+constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
+#endif
// B matrix in LDS memory, dst of blockwise copy
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
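
One closing consideration: the forced padding is not free in LDS. Each tile grows by K0 * pad * K1 elements, and A and B both pay it. A quick budget check with illustrative sizes (fp16 elements, 64 KiB of LDS per workgroup assumed; none of these numbers come from the commit):

#include <cstdio>

constexpr int KPerBlock = 64, AK1 = 8, BK1 = 8;
constexpr int MPerBlock = 128, NPerBlock = 128, ElemSize = 2;
constexpr int AK0 = KPerBlock / AK1, BK0 = KPerBlock / BK1;

constexpr int ldsBytes(int pad)
{
    return AK0 * (MPerBlock + pad) * AK1 * ElemSize  // A tile
         + BK0 * (NPerBlock + pad) * BK1 * ElemSize; // B tile
}

int main()
{
    std::printf("unpadded: %d B, padded: %d B of 65536 B\n",
                ldsBytes(0), ldsBytes(1));
}

Here the pad costs 256 bytes in total, comfortably inside the budget; the configurations to watch are those already sitting at the LDS limit, which the forced pad on gfx950 could push over.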