DeviceGemm_Wmma_CShuffleV3 with BlockGemmPipelineVersion::v3 (#2096)

* Prepare files for DeviceGemm_Wmma_CShuffleV3

* Implement main part of CShuffleV3 with block pipeline v3 for WMMA

* Remove unused functions and template params for A/B descriptors

* Support both gfx11 and gfx12

* Enable SplitK for gfx12 and disable for gfx11

* Added RowColRow layout for DeviceGemmV2 fp16

* Added more instances for Row, Col, Row data layout

* Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Row, Row data layout

* Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Col, Row data layout

* Added more instances for DeviceGemm_Wmma_CShuffleV3, Row, Row, Row data layout

* Fix formatting

* Add documentation

Based on 5585c3121e

* Enable gemm_universal profiling for gfx11/12

* Add WMMA intrinsics for F8/BF8

* Support F8/BF8 DeviceGemm_Wmma_CShuffleV3, add basic instances

* Add BF16 instances and tests

* Fix test_gemm_universal_wmma_fp8 by adding CK_USE_WMMA_FP8

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>

[ROCm/composable_kernel commit: edd92fc546]
This commit is contained in:
Anton Gorenko
2025-04-28 11:14:21 +06:00
committed by GitHub
parent 790e7bbd4b
commit db016cf6da
44 changed files with 5326 additions and 570 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <initializer_list>
@@ -103,8 +103,10 @@ int profile_gemm_universal(int argc, char* argv[])
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
using F8 = ck::f8_t;
#endif
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
using I4 = ck::pk_i4_t;
#endif
@@ -201,7 +203,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
@@ -210,6 +212,8 @@ int profile_gemm_universal(int argc, char* argv[])
{
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
}
#endif
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
{
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});