Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-04-20 14:59:17 +00:00)
Gridwise gemm conv v3 force padded layout on gfx950 (#2961)
* Gridwise gemm conv v3 force padded layout on gfx950
* fix bug in other gridwise
* fix
* Update gridwise_gemm_wmma_cshuffle_v3_common.hpp
This commit is contained in:
@@ -45,7 +45,7 @@ template <typename ALayout,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
index_t ABlockLdsExtraM,
|
||||
index_t ABlockLdsExtraMCustom,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
@@ -53,7 +53,7 @@ template <typename ALayout,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t BBlockLdsExtraNCustom,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -272,12 +272,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
|
||||
#if defined(__gfx950__)
|
||||
// Force use padded layout on gfx950 to reduce bank conflicts
|
||||
constexpr index_t ABlockLdsExtraM = 1;
|
||||
#else
|
||||
constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom;
|
||||
#endif
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(ABlockLdsExtraM)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
|
||||
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
|
||||
}
|
||||
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
|
||||
// in some cases.
|
||||
@@ -412,12 +418,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
|
||||
#if defined(__gfx950__)
|
||||
// Force use padded layout on gfx950 to reduce bank conflicts
|
||||
constexpr index_t BBlockLdsExtraN = 1;
|
||||
#else
|
||||
constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
|
||||
#endif
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(BBlockLdsExtraN)
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
|
||||
make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
|
||||
make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
|
||||
}
|
||||
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
|
||||
{
|
||||
|
||||
@@ -828,7 +828,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
|
||||
// loop to hide it in v4. it may give you some benefit from less valu in compute address
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
|
||||
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
|
||||
}
|
||||
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
|
||||
// in some cases.
|
||||
|
||||
@@ -131,7 +131,7 @@ template <typename ALayout,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool AThreadTransferSrcResetCoordinateAfterRun,
|
||||
index_t ABlockLdsExtraM,
|
||||
index_t ABlockLdsExtraMCustom,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
@@ -139,7 +139,7 @@ template <typename ALayout,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BThreadTransferSrcResetCoordinateAfterRun,
|
||||
index_t BBlockLdsExtraN,
|
||||
index_t BBlockLdsExtraNCustom,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
@@ -698,6 +698,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
|
||||
#if defined(__gfx950__)
|
||||
// Force use padded layout on gfx950 to reduce bank conflicts
|
||||
constexpr index_t ABlockLdsExtraM = 1;
|
||||
#else
|
||||
constexpr index_t ABlockLdsExtraM = ABlockLdsExtraMCustom;
|
||||
#endif
|
||||
|
||||
// A matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
@@ -705,7 +712,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
// loop to hide it in v4. it may give you some benefit from less valu in compute address
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
|
||||
make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
|
||||
make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
|
||||
}
|
||||
// xor tensor transformation request more unnecessary vgpr usage, would cause register spill
|
||||
// in some cases.
|
||||
@@ -840,6 +847,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
|
||||
constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
|
||||
constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
|
||||
constexpr index_t WaveSize = BlockSize / (MWave * NWave);
|
||||
#if defined(__gfx950__)
|
||||
// Force use padded layout on gfx950 to reduce bank conflicts
|
||||
constexpr index_t BBlockLdsExtraN = 1;
|
||||
#else
|
||||
constexpr index_t BBlockLdsExtraN = BBlockLdsExtraNCustom;
|
||||
#endif
|
||||
|
||||
// B matrix in LDS memory, dst of blockwise copy
|
||||
if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user