From 33f4f876cffd32621df945fdc06cbb5cd8aead53 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Wed, 28 Jan 2026 16:32:54 +0000 Subject: [PATCH] Clean up tile_elementwise (casting not needed with new approach) --- .../ck_tile/core/tensor/tile_elementwise.hpp | 49 ------------------- 1 file changed, 49 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 8a6eb90cfb..bc6d7d2f5a 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -282,51 +282,6 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16bf16_fp32(const InTensor& in_dstr_tensors) return out_dstr_tensor; } -template -CK_TILE_DEVICE auto cast_tile_pk_bf16_bf8(const InTensor& in_dstr_tensors) -{ -#if defined(__gfx950__) - // This API is designed to use the _pk_ serious of function - constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); - - constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); - static_assert(thread_buffer_size % 2 == 0); - constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2; - - auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); - - union - { - uint16_t i16val; - bf8_t i8val[2]; - } input; - - union - { - bf16x2_t bhalf_vec; - bf16_t bhalf_arr[2]; - } output; - - // TODO: this is rtz cvt, need be very careful - for(index_t i = 0; i < thread_buffer_size_pk; i++) - { - input.i8val[0] = in_dstr_tensors.get_thread_buffer()[2 * i + 0]; - input.i8val[1] = in_dstr_tensors.get_thread_buffer()[2 * i + 1]; - output.bhalf_vec = - __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0); - - out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = output.bhalf_arr[0]; - out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = output.bhalf_arr[1]; - } - - return out_dstr_tensor; -#else - // fallback - return tile_elementwise_in(type_convert, - in_dstr_tensors); -#endif -} - #if CK_TILE_USE_SUBDWORD_TILE_CAST // this function assume either src or dst (or both) date type is under 1 dword // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) @@ -399,10 +354,6 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) std::is_same_v && (SrcTensor::get_thread_buffer_size() % 4 == 0)) return impl::cast_tile_pk_fp8_fp32(src_tensor); - else if constexpr((std::is_same_v) && - std::is_same_v && - (SrcTensor::get_thread_buffer_size() % 2 == 0)) - return impl::cast_tile_pk_bf16_bf8(src_tensor); #if CK_TILE_USE_PK_FP16_TILE_CAST else if constexpr(std::is_same_v && std::is_same_v &&