[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)

* WMMA GEMM F16 Implementation

Signed-off-by: root <tianyuwu@amd.com>

* Self-review

Signed-off-by: root <tianyuwu@amd.com>

* ASIC check minor tweak

Signed-off-by: root <tianyuwu@amd.com>

* add missing include file

* Set GPU_TARGETS to gfx11/12 generic

Signed-off-by: root <tianyuwu@amd.com>

* INT8 GFX12

Signed-off-by: root <tianyuwu@amd.com>

* add int8x16 branch

* Fix CI script

Signed-off-by: root <tianyuwu@amd.com>

* Fix typo

Signed-off-by: root <tianyuwu@amd.com>

* Add CK_Tile WMMA example

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Fix CI

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* fix clang format

* Set M/N_Warp Back to Constant

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Use GemmConfigComputeV3 by default

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Remove CK_Tile wmma gemm examples from the CI list

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add atomic add fallback method for gfx11

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Omit copyright year

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Support non-square cases

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add get_device_ip()

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Add atomic add fallback method for gfx11"

This reverts commit 07a79e797d.

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12"

This reverts commit ceee918007.

* Revise method name and typos

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Try fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Try fix CI"

This reverts commit 7a7241085e.

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo caused by merge

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Fix typo caused by merging

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

---------

Signed-off-by: root <tianyuwu@amd.com>
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
Tianyuan Wu
2025-08-16 07:22:27 +08:00
committed by GitHub
parent 5ada85ec04
commit 68134b60e4
54 changed files with 1388 additions and 403 deletions

View File

@@ -44,13 +44,13 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
constexpr bool Preshuffle = Problem::Traits::Preshuffle;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
if constexpr(Preshuffle)
@@ -92,13 +92,13 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
static_assert(Problem::kQuantGroupSize % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
false>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);