Clean up tile_elementwise (casting not needed with new approach)

This commit is contained in:
Enrico Degregori
2026-01-28 16:32:54 +00:00
parent e907e1bdf1
commit 33f4f876cf

View File

@@ -282,51 +282,6 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16bf16_fp32(const InTensor& in_dstr_tensors)
return out_dstr_tensor;
}
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_bf16_bf8(const InTensor& in_dstr_tensors)
{
#if defined(__gfx950__)
// This API is designed to use the _pk_ serious of function
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
static_assert(thread_buffer_size % 2 == 0);
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
union
{
uint16_t i16val;
bf8_t i8val[2];
} input;
union
{
bf16x2_t bhalf_vec;
bf16_t bhalf_arr[2];
} output;
// TODO: this is rtz cvt, need be very careful
for(index_t i = 0; i < thread_buffer_size_pk; i++)
{
input.i8val[0] = in_dstr_tensors.get_thread_buffer()[2 * i + 0];
input.i8val[1] = in_dstr_tensors.get_thread_buffer()[2 * i + 1];
output.bhalf_vec =
__builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0);
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = output.bhalf_arr[0];
out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = output.bhalf_arr[1];
}
return out_dstr_tensor;
#else
// fallback
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
in_dstr_tensors);
#endif
}
#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assume either src or dst (or both) date type is under 1 dword
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
@@ -399,10 +354,6 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
std::is_same_v<typename SrcTensor::DataType, float> &&
(SrcTensor::get_thread_buffer_size() % 4 == 0))
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
else if constexpr((std::is_same_v<DstType, bf16_t>) &&
std::is_same_v<typename SrcTensor::DataType, bf8_t> &&
(SrcTensor::get_thread_buffer_size() % 2 == 0))
return impl::cast_tile_pk_bf16_bf8<DstType, SrcTensor>(src_tensor);
#if CK_TILE_USE_PK_FP16_TILE_CAST
else if constexpr(std::is_same_v<DstType, fp16_t> &&
std::is_same_v<typename SrcTensor::DataType, float> &&