mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
Implement grouped gemm tile loop for RDNA4 (#3304)
* feat: grouped gemm tile loop support for RDNA4
* fix: removed extra parameter from grouped gemm example instance
* fix: FP8 check incorrectly enabling FP8 on RDNA3
This commit is contained in:
@@ -334,14 +334,14 @@ struct GridwiseGemm_wmma_cshuffle_v3
|
||||
struct Problem
|
||||
{
|
||||
__host__ Problem() = default;
|
||||
__host__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
__host__ __device__ Problem(index_t M_,
|
||||
index_t N_,
|
||||
index_t K_,
|
||||
std::array<index_t, NumATensor> StrideAs_,
|
||||
std::array<index_t, NumBTensor> StrideBs_,
|
||||
std::array<index_t, NumDTensor> StrideDs_,
|
||||
index_t StrideE_,
|
||||
index_t KBatch_)
|
||||
: M{M_},
|
||||
N{N_},
|
||||
K{K_},
|
||||
|
||||
@@ -351,64 +351,65 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch)
|
||||
// 2D grid (x,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
__host__ __device__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
// Calculate grid size taking into account splitk (KBatch) and multiple groups (Batch)
|
||||
// 3D grid (x,y,z)
|
||||
__host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
__host__ __device__ static auto
|
||||
CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
|
||||
{
|
||||
return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMPadded(index_t M)
|
||||
__host__ __device__ static auto CalculateMPadded(index_t M)
|
||||
{
|
||||
return math::integer_least_multiple(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateNPadded(index_t N)
|
||||
__host__ __device__ static auto CalculateNPadded(index_t N)
|
||||
{
|
||||
return math::integer_least_multiple(N, NPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKPadded(index_t K)
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K)
|
||||
{
|
||||
return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
auto K_t = K_Batch * KPerBlock;
|
||||
return (K + K_t - 1) / K_t * KPerBlock;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
__host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
|
||||
{
|
||||
constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
|
||||
auto K_t = K_Batch * KReadVec;
|
||||
return (K + K_t - 1) / K_t * KReadVec;
|
||||
}
|
||||
|
||||
__host__ static auto CalculateMBlock(index_t M)
|
||||
__host__ __device__ static auto CalculateMBlock(index_t M)
|
||||
{
|
||||
return math::integer_divide_ceil(M, MPerBlock);
|
||||
}
|
||||
|
||||
__host__ static auto CalculateNBlock(index_t N)
|
||||
__host__ __device__ static auto CalculateNBlock(index_t N)
|
||||
{
|
||||
return math::integer_divide_ceil(N, NPerBlock);
|
||||
}
|
||||
@@ -963,14 +964,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
|
||||
}
|
||||
|
||||
__host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
|
||||
__host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / KPerBlock;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user