Extend XDL kernel to Support RDNA3/4 - Part 5 (#2725)

* Enable xdl in gfx11 & gfx12

* update cmake file

* fix all instance build (cmake)

* fix batched_gemm_gemm(cmake)

* rebase cmake files

* fix cmake build error

* remve CK_ENABLE_DYNAMIC_WARP_SIZE

* update cmake build error2

* fix gfx11 build

CK_USE_XDL is enabled on gfx11 and gfx12

* fix gfx10 build

* fix gfx11 error

---------

Co-authored-by: Lin, Qun <Quentin.Lin+amdeng@amd.com>
This commit is contained in:
linqunAMD
2025-09-16 01:59:25 +08:00
committed by GitHub
parent 03b59f8c76
commit f22740df82
33 changed files with 243 additions and 397 deletions

View File

@@ -26,14 +26,21 @@ struct AtomicKernelShape
static constexpr index_t Vector_M = Vector::at(number<0>{});
static constexpr index_t Vector_N = Vector::at(number<1>{});
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;
static constexpr index_t WarpPerBlock_M = MWarps;
static constexpr index_t WarpPerBlock_N = NWarps;
static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t RepeatInWarp =
Warp_M * Warp_N / Vector_M / Vector_N / ck_tile::get_warp_size();
static constexpr index_t RepeatInWarp_M =
(Warp_M / Vector_M > Warp_N / Vector_N) ? RepeatInWarp : 1;
static constexpr index_t RepeatInWarp_N =
(Warp_M / Vector_M > Warp_N / Vector_N) ? 1 : RepeatInWarp;
static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M / RepeatInWarp_M;
static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N / RepeatInWarp_N;
static constexpr index_t Repeat_M = Block_M * RepeatInWarp_M / (WarpPerBlock_M * Warp_M);
static constexpr index_t Repeat_N = Block_N * RepeatInWarp_N / (WarpPerBlock_N * Warp_N);
static constexpr index_t WaveNum = reduce_on_sequence(BlockWaves{}, multiplies{}, number<1>{});
@@ -54,7 +61,10 @@ struct AtomicKernel
using XDataType = typename Problem::XDataType;
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
CK_TILE_HOST static constexpr auto BlockSize()
{
return ck_tile::is_wave32() ? kBlockSize / 2 : kBlockSize;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeTileDistribution()
{