[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)

* WMMA GEMM F16 Implementation

Signed-off-by: root <tianyuwu@amd.com>

* Self-review

Signed-off-by: root <tianyuwu@amd.com>

* ASIC check minor tweak

Signed-off-by: root <tianyuwu@amd.com>

* add missing include file

* Set GPU_TARGETS to gfx11/12 generic

Signed-off-by: root <tianyuwu@amd.com>

* INT8 GFX12

Signed-off-by: root <tianyuwu@amd.com>

* add int8x16 branch

* Fix CI script

Signed-off-by: root <tianyuwu@amd.com>

* Fix typo

Signed-off-by: root <tianyuwu@amd.com>

* Add CK_Tile WMMA example

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Fix CI

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* fix clang format

* Set M/N_Warp Back to Constant

Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>

* Use GemmConfigComputeV3 by default

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Remove CK_Tile wmma gemm examples from the CI list

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add atomic add fallback method for gfx11

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Omit copyright year

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Support non-square cases

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Add get_device_ip()

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Add atomic add fallback method for gfx11"

This reverts commit 07a79e797d.

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12"

This reverts commit ceee918007.

* Revise method name and typos

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Try fix CI

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Revert "Try fix CI"

This reverts commit 7a7241085e.

* clang-format

Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>

* Fix typo caused by merge

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

* Fix typo caused by merging

Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>

---------

Signed-off-by: root <tianyuwu@amd.com>
Signed-off-by: Tianyuan Wu <tianyuwu@amd.com>
Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com>
Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com>
Co-authored-by: joye <joye@amd.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
Tianyuan Wu
2025-08-16 07:22:27 +08:00
committed by GitHub
parent 5ada85ec04
commit 68134b60e4
54 changed files with 1388 additions and 403 deletions

View File

@@ -43,7 +43,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
using WarpGemm = WarpGemmDispatcher<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
@@ -78,18 +78,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
@@ -115,7 +115,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
using WarpGemm = WarpGemmDispatcher<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
@@ -150,18 +150,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true,
false, // SwizzleAccess
false, // UseStructuredSparsity
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
? WGAttrNumAccessEnum ::Double
: WGAttrNumAccessEnum ::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
@@ -187,14 +187,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false>;
using WarpGemm = WarpGemmDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
false>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,

View File

@@ -25,7 +25,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr auto SwizzleA = false;
using WarpGemm = WarpGemmMfmaDispatcher< //
using WarpGemm = WarpGemmDispatcher< //
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
@@ -66,7 +66,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
constexpr auto SwizzleA = false;
using WarpGemm = WarpGemmMfmaDispatcher< //
using WarpGemm = WarpGemmDispatcher< //
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
@@ -106,7 +106,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
typename BlockFmhaShape::Gemm4BlockWarps,
typename BlockFmhaShape::Gemm4WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher< //
using WarpGemm = WarpGemmDispatcher< //
typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -512,14 +512,13 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true>;
using WarpGemm = WarpGemmDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
@@ -546,22 +545,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single>;
using WarpGemm =
WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true,
false,
false,
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
? WGAttrNumAccessEnum::Double
: WGAttrNumAccessEnum::Single>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,

View File

@@ -956,20 +956,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
{
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
// return
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
// Problem::PDataType, typename Problem::VDataType>>>{};
}
else
{
return WarpGemmMfmaDispatcher<
typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>{};
return WarpGemmDispatcher<typename Problem::PDataType,
typename Problem::VDataType,
typename Problem::OaccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>{};
}
}();