mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-26 16:04:58 +00:00
Add fp8 @ bf8 gemm support and example (#933)
* Add f8 bf8 gemm example
* Add element-wise ops
* Add intrinsics
* Update reference calculation
* Add an additional type option for xdlops gemm
* Fix build process
* Add bf8 to buffer addressing
* Update blockwise op, split typeA and typeB
* Update for compatibility
* Uppdate naming to f8->fp8
* Update naming
* Format
[ROCm/composable_kernel commit: bd09b5c538]
This commit is contained in:
@@ -1127,37 +1127,53 @@ amd_buffer_load_invalid_element_return_zero(const T* p_src_wave,
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_LOAD_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t src_addr_shift = src_thread_element_valid ? 0 : 0x80000000;
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
return bit_cast<vector_t>(tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
#endif
|
||||
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
return bit_cast<vector_t>(tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
return amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_addr_shift + src_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
#endif
|
||||
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
{
|
||||
auto tmp = amd_buffer_load_impl<int8_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? bit_cast<vector_t>(tmp) : vector_t(0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
vector_t tmp = amd_buffer_load_impl<scalar_t, vector_size, coherence>(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 0);
|
||||
return src_thread_element_valid ? tmp : vector_t(0);
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
@@ -1216,40 +1232,61 @@ __device__ void amd_buffer_store(const typename vector_type_maker<T, N>::type::t
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK
|
||||
uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x80000000;
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp =
|
||||
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(src_thread_data);
|
||||
amd_buffer_store_impl<int8_t, vector_size, coherence>(
|
||||
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
#endif
|
||||
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
{
|
||||
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
|
||||
src_thread_data);
|
||||
amd_buffer_store_impl<int8_t, vector_size, coherence>(
|
||||
tmp, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_addr_shift +
|
||||
dst_thread_addr_offset,
|
||||
0);
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
}
|
||||
#endif
|
||||
#else
|
||||
if(dst_thread_element_valid)
|
||||
{
|
||||
#if defined CK_ENABLE_FP8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
{
|
||||
auto tmp = bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
|
||||
src_thread_data);
|
||||
amd_buffer_store_impl<int8_t, vector_size, coherence>(
|
||||
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value || is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8
|
||||
}
|
||||
#if defined CK_ENABLE_FP8 && !defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, f8_t>::value)
|
||||
#endif
|
||||
#if !defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
if constexpr(is_same<scalar_t, bf8_t>::value)
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
{
|
||||
auto tmp =
|
||||
bit_cast<typename vector_type_maker<int8_t, vector_size>::type::type>(
|
||||
src_thread_data);
|
||||
amd_buffer_store_impl<int8_t, vector_size, coherence>(
|
||||
tmp, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
#endif
|
||||
amd_buffer_store_impl<scalar_t, vector_size, coherence>(
|
||||
src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0);
|
||||
#if defined CK_ENABLE_FP8 || defined CK_ENABLE_BF8
|
||||
}
|
||||
#endif
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -419,5 +419,70 @@ struct intrin_mfma_f32_16x16x32f8f8<16, 16>
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x16f8bf8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x16f8bf8<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<f8_t, 8> reg_a_v(reg_a);
|
||||
vector_type<bf8_t, 8> reg_b_v(reg_b);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
|
||||
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
|
||||
|
||||
intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x32f8bf8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_16x16x32f8bf8<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
|
||||
bit_cast<long>(reg_a),
|
||||
bit_cast<long>(reg_b),
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
#else
|
||||
vector_type<f8_t, 8> reg_a_v(reg_a);
|
||||
vector_type<bf8_t, 8> reg_b_v(reg_b);
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto k) {
|
||||
float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
|
||||
float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
|
||||
|
||||
intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
|
||||
});
|
||||
#endif
|
||||
}
|
||||
};
|
||||
#endif
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user