From ac4f4ffb790b796a3f153d7ff24d596b5dbc4af2 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Fri, 21 Nov 2025 02:50:26 +0800 Subject: [PATCH] [CK_TILE] Refine FP32 => FP16/BF16 Conversion (#3215) * [CK_TILE] Refine FP32 => FP16/BF16 Conversion * Thank you Copilot * Rename fix * Fix example * Fix accu checking * Fix * Fix [ROCm/composable_kernel commit: 8b284a63a4d7c99c41b6885ac37dbb7874c8737d] --- include/ck_tile/core/numeric/bfloat16.hpp | 15 +++++++- include/ck_tile/core/numeric/half.hpp | 6 +++ include/ck_tile/core/numeric/type_convert.hpp | 3 ++ .../ck_tile/core/tensor/tile_elementwise.hpp | 38 +++++++++++++++---- ..._batch_prefill_pipeline_qr_ks_vs_async.hpp | 4 +- .../block_fmha_pipeline_qr_ks_vs_async.hpp | 4 +- .../grouped_gemm_multi_d/CMakeLists.txt | 5 ++- 7 files changed, 61 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/core/numeric/bfloat16.hpp b/include/ck_tile/core/numeric/bfloat16.hpp index b17890b733..5caee28e2e 100644 --- a/include/ck_tile/core/numeric/bfloat16.hpp +++ b/include/ck_tile/core/numeric/bfloat16.hpp @@ -283,7 +283,10 @@ template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant = {}) { -#if CK_TILE_USE_LLVM_BUILTIN_BF16 +// Use builtin bfloat16 conversion only on gfx950 as its predecessors do not support bf16 cvt +// instructions, resulting in suboptimal performance; Add host side macro check for consistency +// during accuracy tests. 
+#if CK_TILE_USE_LLVM_BUILTIN_BF16 && (defined(__gfx950__) || defined(CK_GFX950_SUPPORT)) return static_cast(f); #else return bit_cast(float_to_bf16_raw(f, constant{})); @@ -427,4 +430,14 @@ bfloat16_t exp2(bfloat16_t x) { return static_cast(exp2f(static_cast CK_TILE_DEVICE bfloat16_t log(bfloat16_t x) { return static_cast(__logf(static_cast(x))); }; +using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2))); +using fp32x2_t = float __attribute__((ext_vector_type(2))); + +template (CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)> +CK_TILE_HOST_DEVICE constexpr bf16x2_t fp32x2_to_bf16x2(const fp32x2_t& x) +{ + return bf16x2_t{float_to_bf16(x.x), float_to_bf16(x.y)}; +} + } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/half.hpp b/include/ck_tile/core/numeric/half.hpp index 8479b33f8f..128befe90f 100644 --- a/include/ck_tile/core/numeric/half.hpp +++ b/include/ck_tile/core/numeric/half.hpp @@ -383,6 +383,7 @@ half_t log(half_t x) { return static_cast(__logf(static_cast(x))) #endif using fp16x2_t = _Float16 __attribute__((ext_vector_type(2))); +using fp32x2_t = float __attribute__((ext_vector_type(2))); CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) { @@ -401,4 +402,9 @@ CK_TILE_DEVICE fp16x2_t pk_add_f16(const fp16x2_t& x, const fp16x2_t& y) return c; } +CK_TILE_HOST_DEVICE +constexpr fp16x2_t fp32x2_to_fp16x2(const fp32x2_t& x) +{ + return fp16x2_t{float_to_fp16(x.x), float_to_fp16(x.y)}; +} } // namespace ck_tile diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index 1455fce0ea..3fee3ef96c 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -64,6 +64,9 @@ CK_TILE_TYPE_CONVERT(bf8_t, bf8, float, float) CK_TILE_TYPE_CONVERT(float, float, int8_t, int8) CK_TILE_TYPE_CONVERT(int8_t, int8, float, float) + +CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2) +CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2) 
#undef CK_TILE_TYPE_CONVERT } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 1863192a1f..4ab4c78884 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -228,7 +228,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) } template -CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors) +CK_TILE_DEVICE auto cast_tile_pkrtz_fp16_fp32(const InTensor& in_dstr_tensors) { #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) // This API is designed to use the _pk_ serious of function @@ -258,6 +258,30 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors) #endif } +template +CK_TILE_DEVICE auto cast_tile_pk_fp16bf16_fp32(const InTensor& in_dstr_tensors) +{ + // This API is designed to help compiler to identify pairs of f32 -> fp16/bf16 cast and use + // cvt_pk instruction when possible + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); + + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); + static_assert(thread_buffer_size % 2 == 0); + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + using f16x2_t = std::conditional_t, fp16x2_t, bf16x2_t>; + for(index_t i = 0; i < thread_buffer_size / 2; i++) + { + auto o = type_convert(fp32x2_t{ + in_dstr_tensors.get_thread_buffer()[2 * i + 0], + in_dstr_tensors.get_thread_buffer()[2 * i + 1], + }); + + out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x; + out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y; + } + return out_dstr_tensor; +} + #if CK_TILE_USE_SUBDWORD_TILE_CAST // this function assume either src or dst (or both) date type is under 1 dword // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) @@ -329,22 +353,20 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& 
src_tensor) if constexpr((std::is_same_v || std::is_same_v) && std::is_same_v && (SrcTensor::get_thread_buffer_size() % 4 == 0)) - { return impl::cast_tile_pk_fp8_fp32(src_tensor); - } #if CK_TILE_USE_PK_FP16_TILE_CAST else if constexpr(std::is_same_v && std::is_same_v && (SrcTensor::get_thread_buffer_size() % 2 == 0)) - { - return impl::cast_tile_pk_fp16_fp32(src_tensor); - } + return impl::cast_tile_pkrtz_fp16_fp32(src_tensor); #endif + else if constexpr((std::is_same_v || std::is_same_v) && + std::is_same_v && + (SrcTensor::get_thread_buffer_size() % 2 == 0)) + return impl::cast_tile_pk_fp16bf16_fp32(src_tensor); #if CK_TILE_USE_SUBDWORD_TILE_CAST else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) - { return impl::cast_tile_opt_subdword(src_tensor); - } #endif else return tile_elementwise_in(type_convert, src_tensor); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp index 6398bf316e..59fa9139bf 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp @@ -704,12 +704,12 @@ struct BlockFmhaBatchPrefillPipelineQRKSVSAsync const auto p = [&]() { #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN // For fp32 to fp16, - // impl::cast_tile_pk_fp16_fp32 would cause precision issue, + // impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue, // since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero. 
return cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); #else if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32( + return impl::cast_tile_pkrtz_fp16_fp32( tile_elementwise_in(p_compute_element_func, p_compute)); else return cast_tile( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index b67c28401f..da48802c76 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -657,12 +657,12 @@ struct BlockFmhaPipelineQRKSVSAsync const auto p = [&]() { #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN // For fp32 to fp16, - // impl::cast_tile_pk_fp16_fp32 would cause precision issue, + // impl::cast_tile_pkrtz_fp16_fp32 would cause precision issue, // since it uses __builtin_amdgcn_cvt_pkrtz, which is round to zero. return cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); #else if constexpr(std::is_same_v) - return impl::cast_tile_pk_fp16_fp32( + return impl::cast_tile_pkrtz_fp16_fp32( tile_elementwise_in(p_compute_element_func, p_compute)); else return cast_tile( diff --git a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt index 20c4cbc1c3..845da28b5d 100644 --- a/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt +++ b/test/ck_tile/grouped_gemm_multi_d/CMakeLists.txt @@ -3,7 +3,10 @@ if(CK_USE_OCP_FP8) list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8) endif() +# Use standard asm for rtn bf16 conversion instead of truncation +list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3) + if(GPU_TARGETS MATCHES "gfx94|gfx95") add_gtest_executable(test_ck_tile_grouped_gemm_multi_d test_grouped_gemm_multi_d.cpp) target_compile_options(test_ck_tile_grouped_gemm_multi_d PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) -endif() \ No newline at end 
of file +endif()