mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_TILE] CK_TILE GEMM WMMA Support for GFX11/GFX12 (#2466)
* WMMA GEMM F16 Implementation Signed-off-by: root <tianyuwu@amd.com> * Self-review Signed-off-by: root <tianyuwu@amd.com> * ASIC check minor tweak Signed-off-by: root <tianyuwu@amd.com> * add missing include file * Set GPU_TARGETS to gfx11/12 generic Signed-off-by: root <tianyuwu@amd.com> * INT8 GFX12 Signed-off-by: root <tianyuwu@amd.com> * add int8x16 branch * Fix CI script Signed-off-by: root <tianyuwu@amd.com> * Fix typo Signed-off-by: root <tianyuwu@amd.com> * Add CK_Tile WMMA example Signed-off-by: Tianyuan Wu <tianyuwu@amd.com> * Fix CI Signed-off-by: Tianyuan Wu <tianyuwu@amd.com> * fix clang format * Set M/N_Warp Back to Constant Signed-off-by: Tianyuan Wu <tianyuwu@amd.com> * Use GemmConfigComputeV3 by default Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12 Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Remove CK_Tile wmma gemm examples from the CI list Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Add atomic add fallback method for gfx11 Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Fix typo Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Omit copyright year Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Support non-square cases Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Fix CI Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Add get_device_ip() Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Revert "Add atomic add fallback method for gfx11" This reverts commit07a79e797d. Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> * Revert "Enable CK_TILE_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT for gfx12" This reverts commitceee918007. * Revise method name and typos Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> * clang-format Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Try fix CI Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Revert "Try fix CI" This reverts commit7a7241085e. * clang-format Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> * Fix typo caused by merge Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> * Fix typo caused by merging Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> --------- Signed-off-by: root <tianyuwu@amd.com> Signed-off-by: Tianyuan Wu <tianyuwu@amd.com> Signed-off-by: TianyuanWu <Tianyuan.Wu@amd.com> Signed-off-by: Tianyuan Wu <Tianyuan.Wu@amd.com> Co-authored-by: joye <joye@amd.com> Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
This commit is contained in:
@@ -43,7 +43,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
using WarpGemm = WarpGemmDispatcher<
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -78,18 +78,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
WarpGemmDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
@@ -115,7 +115,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
using WarpGemm = WarpGemmDispatcher<
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -150,18 +150,18 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
WarpGemmDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false, // SwizzleAccess
|
||||
false, // UseStructuredSparsity
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
|
||||
? WGAttrNumAccessEnum ::Double
|
||||
: WGAttrNumAccessEnum ::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
@@ -187,14 +187,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
|
||||
false>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
|
||||
false>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
|
||||
|
||||
@@ -25,7 +25,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto SwizzleA = false;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -66,7 +66,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
|
||||
|
||||
constexpr auto SwizzleA = false;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::AccDataType,
|
||||
@@ -106,7 +106,7 @@ struct BlockFmhaBwdPipelineTrLoadDefaultPolicy
|
||||
typename BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename BlockFmhaShape::Gemm4WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher< //
|
||||
using WarpGemm = WarpGemmDispatcher< //
|
||||
typename Problem::GemmDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -512,14 +512,13 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
|
||||
@@ -546,22 +545,22 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single>;
|
||||
using WarpGemm =
|
||||
WarpGemmDispatcher<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
|
||||
|
||||
@@ -956,20 +956,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
{
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{};
|
||||
// return
|
||||
// WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmImpl<WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB<
|
||||
// WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<typename
|
||||
// Problem::PDataType, typename Problem::VDataType>>>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
return WarpGemmMfmaDispatcher<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
return WarpGemmDispatcher<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true>{};
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user