mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK TILE][AICK-439] Fix cshuffle epilogue wave per shuffle (#3364)
* [CK TILE] Fix cshuffle epilogue wave per shuffle * Align shuffle per tile with smem * fixes * Fixes for double smem * fix
This commit is contained in:
@@ -35,7 +35,8 @@ template <typename AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeC_ = 1,
|
||||
bool TiledMMAPermuteN_ = false,
|
||||
index_t BlockedXDLN_PerWarp_ = 1> // The number of continuous xdl_output per warp
|
||||
index_t BlockedXDLN_PerWarp_ = 1, // The number of continuous xdl_output per warp
|
||||
bool DoubleSmemBuffer_ = false>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
@@ -59,6 +60,7 @@ struct CShuffleEpilogueProblem
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
static constexpr index_t VectorSizeC = VectorSizeC_;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
|
||||
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
@@ -118,6 +120,7 @@ struct CShuffleEpilogue
|
||||
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
|
||||
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
|
||||
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
|
||||
static constexpr index_t MPerIteration = MPerXdl * MWave;
|
||||
static constexpr index_t NPerIteration = NPerXdl * NWave;
|
||||
@@ -204,6 +207,26 @@ struct CShuffleEpilogue
|
||||
}
|
||||
return max_vector_size / sizeof(DiDataType);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters check and alignment
|
||||
*
|
||||
* @details Return tuple(1, 1) if shuffle_tile values are too large for SMEM.
|
||||
*/
|
||||
template <index_t m_shuffle_tile, index_t n_shuffle_tile>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto AlignShuffleTileWithSmem()
|
||||
{
|
||||
constexpr index_t m_val = MPerXdl * MWave * m_shuffle_tile;
|
||||
constexpr index_t n_val = NPerXdl * NWave * n_shuffle_tile;
|
||||
|
||||
constexpr auto shuffle_tile =
|
||||
m_val * n_val * sizeof(ODataType) > get_smem_capacity() || DoubleSmemBuffer
|
||||
? std::make_tuple(1, 1)
|
||||
: std::make_tuple(m_shuffle_tile, n_shuffle_tile);
|
||||
|
||||
return shuffle_tile;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Shuffle tile configuration parameters
|
||||
*
|
||||
@@ -214,20 +237,23 @@ struct CShuffleEpilogue
|
||||
*/
|
||||
static constexpr auto shuffle_tile_tuple = [] {
|
||||
constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
|
||||
if constexpr(elem_per_thread >= GetVectorSizeC())
|
||||
if constexpr(elem_per_thread <= GetVectorSizeC())
|
||||
{
|
||||
return std::make_tuple(1, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
|
||||
constexpr index_t num_xdl_shuffles = elem_per_thread / GetVectorSizeC();
|
||||
static_assert(elem_per_thread % GetVectorSizeC() == 0);
|
||||
if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
|
||||
(kMPerBlock % num_xdl_shuffles == 0),
|
||||
"kMPerBlock must be divisible by MPerXdl*MWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
|
||||
return AlignShuffleTileWithSmem<min(num_xdl_shuffles,
|
||||
kMPerBlock / (MPerXdl * MWave)),
|
||||
1>();
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -235,7 +261,9 @@ struct CShuffleEpilogue
|
||||
(kNPerBlock % num_xdl_shuffles == 0),
|
||||
"kNPerBlock must be divisible by NPerXdl*NWave and "
|
||||
"num_xdl_shuffles for CShuffleEpilogue");
|
||||
return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
|
||||
return AlignShuffleTileWithSmem<1,
|
||||
min(num_xdl_shuffles,
|
||||
kNPerBlock / (NPerXdl * NWave))>();
|
||||
}
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -232,7 +232,7 @@ struct BatchedGemmKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr1[GetSmemSize()];
|
||||
__shared__ char smem_ptr1[GemmPipeline::GetSmemSize()];
|
||||
UniversalGemmKernel::RunGemm2LDS({a_ptr},
|
||||
{b_ptr},
|
||||
{/*ds_ptr*/},
|
||||
|
||||
@@ -310,7 +310,7 @@ struct GroupedGemmKernel
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
c_ptr,
|
||||
|
||||
@@ -1084,7 +1084,7 @@ struct UniversalGemmKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
is_any_of<EDataType, fp16_t, bf16_t>::value))
|
||||
@@ -1169,7 +1169,7 @@ struct UniversalGemmKernel
|
||||
// Run the GEMM
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
|
||||
|
||||
@@ -1324,7 +1324,7 @@ struct QuantGemmKernel
|
||||
assert(kargs.k_batch == 1);
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
|
||||
RunGemm2LDS(a_ptr,
|
||||
b_ptr,
|
||||
|
||||
@@ -325,7 +325,7 @@ struct QuantGroupedGemmKernel
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
|
||||
@@ -1048,7 +1048,7 @@ struct GroupedConvolutionBackwardDataKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
is_any_of<OutDataType, fp16_t, bf16_t>::value))
|
||||
|
||||
@@ -1005,7 +1005,7 @@ struct GroupedConvolutionBackwardWeightKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
|
||||
@@ -1184,7 +1184,7 @@ struct GroupedConvolutionForwardKernel
|
||||
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true)
|
||||
{
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
__shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()];
|
||||
if constexpr(!(EpiloguePipeline::MemoryOperation ==
|
||||
memory_operation_enum::atomic_add &&
|
||||
GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
|
||||
|
||||
Reference in New Issue
Block a user