mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
Extend XDL kernel to Support RDNA3/4 - Part 5 (#2725)
* Enable xdl in gfx11 & gfx12 * update cmake file * fix all instance build (cmake) * fix batched_gemm_gemm(cmake) * rebase cmake files * fix cmake build error * remve CK_ENABLE_DYNAMIC_WARP_SIZE * update cmake build error2 * fix gfx11 build CK_USE_XDL is enabled on gfx11 and gfx12 * fix gfx10 build * fix gfx11 error --------- Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
This commit is contained in:
@@ -26,14 +26,21 @@ struct AtomicKernelShape
|
||||
static constexpr index_t Vector_M = Vector::at(number<0>{});
|
||||
static constexpr index_t Vector_N = Vector::at(number<1>{});
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
|
||||
|
||||
static constexpr index_t WarpPerBlock_M = MWarps;
|
||||
static constexpr index_t WarpPerBlock_N = NWarps;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
|
||||
static constexpr index_t RepeatInWarp =
|
||||
Warp_M * Warp_N / Vector_M / Vector_N / ck_tile::get_warp_size();
|
||||
static constexpr index_t RepeatInWarp_M =
|
||||
(Warp_M / Vector_M > Warp_N / Vector_N) ? RepeatInWarp : 1;
|
||||
static constexpr index_t RepeatInWarp_N =
|
||||
(Warp_M / Vector_M > Warp_N / Vector_N) ? 1 : RepeatInWarp;
|
||||
|
||||
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M / RepeatInWarp_M;
|
||||
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N / RepeatInWarp_N;
|
||||
|
||||
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
|
||||
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
|
||||
|
||||
static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{});
|
||||
|
||||
@@ -54,7 +61,10 @@ struct AtomicKernel
|
||||
using XDataType = typename Problem::XDataType;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize()
|
||||
{
|
||||
return ck_tile::is_wave32() ? kBlockSize / 2 : kBlockSize;
|
||||
}
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user