mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
DeviceGemm_Wmma_CShuffleV3 with BlockGemmPipelineVersion::v3 (#2096)
* Prepare files for DeviceGemm_Wmma_CShuffleV3
* Implement main part of CShuffleV3 with block pipeline v3 for WMMA
* Remove unused functions and template params for A/B descriptors
* Support both gfx11 and gfx12
* Enable SplitK for gfx12 and disable for gfx11
* Added RowColRow layout for DeviceGemmV2 fp16
* Added more instances for Row, Col, Row data layout
* Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Row, Row data layout
* Added instances for DeviceGemm_Wmma_CShuffleV3, Col, Col, Row data layout
* Added more instances for DeviceGemm_Wmma_CShuffleV3, Row, Row, Row data layout
* Fix formatting
* Add documentation
Based on e5ad48a784
* Enable gemm_universal profiling for gfx11/12
* Add WMMA intrinsics for F8/BF8
* Support F8/BF8 DeviceGemm_Wmma_CShuffleV3, add basic instances
* Add BF16 instances and tests
* Fix test_gemm_universal_wmma_fp8 by adding CK_USE_WMMA_FP8
---------
Co-authored-by: Anca Hamuraru <anca@streamhpc.com>
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -22,6 +22,10 @@ enum struct WmmaInstr
|
||||
wmma_f32_16x16x16_f16_gfx12,
|
||||
wmma_f32_16x16x16_bf16_gfx12,
|
||||
wmma_i32_16x16x16_iu8_gfx12,
|
||||
wmma_f32_16x16x16_f8f8_gfx12,
|
||||
wmma_f32_16x16x16_f8bf8_gfx12,
|
||||
wmma_f32_16x16x16_bf8f8_gfx12,
|
||||
wmma_f32_16x16x16_bf8bf8_gfx12,
|
||||
};
|
||||
|
||||
/*
|
||||
@@ -400,6 +404,146 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
intrin_wmma_f32_16x16x16_f8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
#else
|
||||
ignore = a;
|
||||
ignore = b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
intrin_wmma_f32_16x16x16_f8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
#else
|
||||
ignore = a;
|
||||
ignore = b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
intrin_wmma_f32_16x16x16_bf8f8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
#else
|
||||
ignore = a;
|
||||
ignore = b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t WaveSize>
|
||||
struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12,
|
||||
WaveSize,
|
||||
typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
|
||||
{
|
||||
// Absolute fixing property
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
|
||||
// Wave mode dependent propety
|
||||
static constexpr index_t wave_size = Number<WaveSize>{};
|
||||
static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
|
||||
static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
|
||||
|
||||
template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
|
||||
if constexpr(wave_size == 32)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
intrin_wmma_f32_16x16x16_bf8bf8_w32_gfx12<MPerWmma, NPerWmma>::Run(a, b, reg_c);
|
||||
#else
|
||||
ignore = a;
|
||||
ignore = b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename src_type_a,
|
||||
typename src_type_b,
|
||||
typename dst_type,
|
||||
@@ -463,6 +607,31 @@ struct WmmaSelector
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu4;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
constexpr auto GetWmma<f8_t, f8_t, float, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetWmma<f8_t, bf8_t, float, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetWmma<bf8_t, f8_t, float, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12;
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetWmma<bf8_t, bf8_t, float, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12;
|
||||
}
|
||||
|
||||
// get_warp_size does not return the correct wave size, so hardcode it to 32 as a workaround
|
||||
static constexpr auto selected_wmma =
|
||||
wmma_type<GetWmma<src_type_a, src_type_b, dst_type, MPerWmma, NPerWmma>(), Number<32>{}>{};
|
||||
@@ -612,14 +781,17 @@ struct WmmaGemm
|
||||
(is_same<src_type_a, bhalf_t>::value && is_same<src_type_b, bhalf_t>::value &&
|
||||
is_same<dst_type, bhalf_t>::value) ||
|
||||
(is_same<src_type_a, int8_t>::value && is_same<src_type_b, int8_t>::value &&
|
||||
is_same<dst_type, int32_t>::value)
|
||||
is_same<dst_type, int32_t>::value) ||
|
||||
((is_same<src_type_a, f8_t>::value || is_same<src_type_a, bf8_t>::value) &&
|
||||
(is_same<src_type_b, f8_t>::value || is_same<src_type_b, bf8_t>::value) &&
|
||||
is_same<dst_type, float>::value) ||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
|| (is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
|
||||
is_same<dst_type, int32_t>::value)
|
||||
(is_same<src_type_a, int4_t>::value && is_same<src_type_b, int4_t>::value &&
|
||||
is_same<dst_type, int32_t>::value) ||
|
||||
#endif
|
||||
,
|
||||
false,
|
||||
"base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
|
||||
"(int8, int32) or (int4, int32)!");
|
||||
"((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
|
||||
static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user