mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add fp8 GEMM and an example for it (#767)
* Add fp8 xdl gemm * Add example * Use int8 intrinsics for buffer load/store * Format * Update cmakelists
This commit is contained in:
@@ -44,3 +44,7 @@ if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS
|
||||
add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
|
||||
add_example_executable(example_gemm_xdl_f8 gemm_xdl_f8.cpp)
|
||||
add_dependencies(example_gemm_xdl example_gemm_xdl_f8)
|
||||
endif()
|
||||
|
||||
38
example/01_gemm/gemm_xdl_f8.cpp
Normal file
38
example/01_gemm/gemm_xdl_f8.cpp
Normal file
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp"
|
||||
|
||||
using ADataType = ck::f8_t;
|
||||
using BDataType = ck::f8_t;
|
||||
using CDataType = ck::f8_t;
|
||||
using AccDataType = float;
|
||||
using CShuffleDataType = ck::f8_t;
|
||||
|
||||
using ALayout = Row;
|
||||
using BLayout = Col;
|
||||
using CLayout = Row;
|
||||
|
||||
using AElementOp = PassThrough;
|
||||
using BElementOp = PassThrough;
|
||||
using CElementOp = PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
|
||||
|
||||
// clang-format off
|
||||
using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
|
||||
// ######| ALayout| BLayout| CLayout| AData| BData| CData| AccData| CShuffle| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
// ######| | | | Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
// ######| | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 16>;
|
||||
// clang-format on
|
||||
|
||||
using ReferenceGemmInstance = ck::tensor_operation::host::
|
||||
ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;
|
||||
|
||||
#include "run_gemm_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_gemm_example(argc, argv); }
|
||||
@@ -29,7 +29,9 @@ enum struct MfmaInstr
|
||||
mfma_i32_16x16x16i8,
|
||||
mfma_i32_32x32x16i8,
|
||||
mfma_i32_16x16x32i8,
|
||||
mfma_f64_16x16x4f64
|
||||
mfma_f64_16x16x4f64,
|
||||
mfma_f32_32x32x16f8f8,
|
||||
mfma_f32_16x16x32f8f8
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -454,6 +456,50 @@ struct mfma_type<MfmaInstr::mfma_f64_16x16x4f64>
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x16f8f8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_32x32x16f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x32f8f8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x32f8f8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
|
||||
struct MfmaSelector
|
||||
{
|
||||
@@ -594,6 +640,18 @@ struct MfmaSelector
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
}
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
|
||||
__host__ __device__ constexpr MfmaSelector()
|
||||
@@ -794,7 +852,7 @@ struct XdlopsGemm
|
||||
{
|
||||
static_assert(is_same<base_type, double>::value || is_same<base_type, float>::value ||
|
||||
is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value ||
|
||||
is_same<base_type, int8_t>::value,
|
||||
is_same<base_type, int8_t>::value || is_same<base_type, f8_t>::value,
|
||||
"base base_type must be double, float, half, bfloat16, and int8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
|
||||
@@ -1114,13 +1114,30 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
return bit_cast<vector_t>(tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
}
|
||||
#else
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -1179,13 +1196,33 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp =
|
||||
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data);
|
||||
amd_buffer_store_impl<int8_t, vector_size, coherence>(
|
||||
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
}
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
|
||||
src_thread_data);
|
||||
amd_buffer_store_impl<int8_t, vector_size, coherence>(
|
||||
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -354,5 +354,68 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x16f8f8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x16f8f8<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx940__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<f8_t, 8> reg_a_v(reg_a);
|
||||
vector_type<f8_t, 8> reg_b_v(reg_b);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
|
||||
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
|
||||
|
||||
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x32f8f8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x32f8f8<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx940__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<f8_t, 8> reg_a_v(reg_a);
|
||||
vector_type<f8_t, 8> reg_b_v(reg_b);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
|
||||
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
|
||||
|
||||
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user