mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
DeviceGemm_Wmma_CShuffleV3 with BlockGemmPipelineVersion::v3 (#2096)
* Prepare files for DeviceGemm_Wmma_CShuffleV3 * Implement main part of CShuffleV3 with block pipeline v3 for WMMA * Remove unused functions and template params for A/B descriptors * Support both gfx11 and gfx12 * Enable SplitK for gfx12 and disable for gfx11 * Added RowColRow layout for DeviceGemmV2 fp16 * Added more instances for Row, Col, Row data layout * Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Row, Row data layout * Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Col, Row data layout * Added more instances for DeviceGemm_Wmma_CShuffleV3, Row, Row, Row data layout * Fix formatting * Add documentation Based on5585c3121e* Enable gemm_universal profiling for gfx11/12 * Add WMMA intrinsics for F8/BF8 * Support F8/BF8 DeviceGemm_Wmma_CShuffleV3, add basic instances * Add BF16 instances and tests * Fix test_gemm_universal_wmma_fp8 by adding CK_USE_WMMA_FP8 --------- Co-authored-by: Anca Hamuraru <anca@streamhpc.com> [ROCm/composable_kernel commit:edd92fc546]
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <initializer_list>
|
||||
@@ -103,8 +103,10 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
using I4 = ck::pk_i4_t;
|
||||
#endif
|
||||
|
||||
@@ -201,7 +203,7 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{});
|
||||
}
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) || defined(CK_USE_WMMA_FP8)
|
||||
else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{});
|
||||
@@ -210,6 +212,8 @@ int profile_gemm_universal(int argc, char* argv[])
|
||||
{
|
||||
return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
#endif
|
||||
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)
|
||||
else if(data_type == GemmDataType::F16_I4_F16 && layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
return profile(F16{}, I4{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{});
|
||||
|
||||
Reference in New Issue
Block a user