DeviceGemm_Wmma_CShuffleV3 with BlockGemmPipelineVersion::v3 (#2096)

* Prepare files for DeviceGemm_Wmma_CShuffleV3

* Implement main part of CShuffleV3 with block pipeline v3 for WMMA

* Remove unused functions and template params for A/B descriptors

* Support both gfx11 and gfx12

* Enable SplitK for gfx12 and disable for gfx11

* Added RowColRow layout for DeviceGemmV2 fp16

* Added more instances for Row, Col, Row data layout

* Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Row, Row data layout

* Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Col, Row data layout

* Added more instances for DeviceGemm_Wmma_CShuffleV3, Row, Row, Row data layout

* Fix formatting

* Add documentation

Based on e5ad48a784

* Enable gemm_universal profiling for gfx11/12

* Add WMMA intrinsics for F8/BF8

* Support F8/BF8 DeviceGemm_Wmma_CShuffleV3, add basic instances

* Add BF16 instances and tests

* Fix test_gemm_universal_wmma_fp8 by adding CK_USE_WMMA_FP8

---------

Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
This commit is contained in:
Anton Gorenko
2025-04-28 11:14:21 +06:00
committed by GitHub
parent 8add2cf45d
commit edd92fc546
44 changed files with 5326 additions and 570 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -22,6 +22,10 @@ enum struct WmmaInstr
wmma_f32_16x16x16_f16_gfx12,
wmma_f32_16x16x16_bf16_gfx12,
wmma_i32_16x16x16_iu8_gfx12,
wmma_f32_16x16x16_f8f8_gfx12,
wmma_f32_16x16x16_f8bf8_gfx12,
wmma_f32_16x16x16_bf8f8_gfx12,
wmma_f32_16x16x16_bf8bf8_gfx12,
};
/*
@@ -400,6 +404,146 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
if constexpr(wave_size == 32)
{
#ifdef __gfx12__
intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
#else
ignore = a;
ignore = b;
ignore = reg_c;
#endif
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
if constexpr(wave_size == 32)
{
#ifdef __gfx12__
intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
#else
ignore = a;
ignore = b;
ignore = reg_c;
#endif
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
if constexpr(wave_size == 32)
{
#ifdef __gfx12__
intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
#else
ignore = a;
ignore = b;
ignore = reg_c;
#endif
}
}
};
template <index_t WaveSize>
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12,
WaveSize,
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
{
// Absolute fixing property
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
// Wave mode dependent propety
static constexpr index_t wave_size = Number<WaveSize>{};
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
if constexpr(wave_size == 32)
{
#ifdef __gfx12__
intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
#else
ignore = a;
ignore = b;
ignore = reg_c;
#endif
}
}
};
template <typename src_type_a,
typename src_type_b,
typename dst_type,
@@ -463,6 +607,31 @@ struct WmmaSelector
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
#endif
template <>
constexpr auto GetWmma<f8_t, f8_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
}
template <>
constexpr auto GetWmma<f8_t, bf8_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
}
template <>
constexpr auto GetWmma<bf8_t, f8_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
}
template <>
constexpr auto GetWmma<bf8_t, bf8_t, float, 16, 16>()
{
return WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
}
// get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
static constexpr auto selected_wmma =
wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
@@ -612,14 +781,17 @@ struct WmmaGemm
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
is_same<dst_type, bhalf_t>::value) ||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
is_same<dst_type, int32_t>::value)
is_same<dst_type, int32_t>::value) ||
((is_same<src_type_a, f8_t>::value || is_same<src_type_a, bf8_t>::value) &&
(is_same<src_type_b, f8_t>::value || is_same<src_type_b, bf8_t>::value) &&
is_same<dst_type, float>::value) ||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|| (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
is_same<dst_type, int32_t>::value)
(is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
is_same<dst_type, int32_t>::value) ||
#endif
,
false,
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
"(int8, int32) or (int4, int32)!");
"((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
if constexpr(!TransposeC)
{