diff --git a/include/ck_tile/core/arch/arch.hpp b/include/ck_tile/core/arch/arch.hpp index 888f0e728f..4a69f67ae3 100644 --- a/include/ck_tile/core/arch/arch.hpp +++ b/include/ck_tile/core/arch/arch.hpp @@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; } CK_TILE_DEVICE void block_sync_lds() { #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM - asm volatile("\ - s_waitcnt lgkmcnt(0) \n \ - s_barrier \ - " ::); + // asm volatile("\ + // s_waitcnt lgkmcnt(0) \n \ + // s_barrier \ + // " ::); + + __builtin_amdgcn_s_waitcnt(0xc07f); + __builtin_amdgcn_s_barrier(); #else __syncthreads(); #endif diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 10045d8f7d..344343d931 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -167,6 +167,10 @@ #define CK_TILE_USE_SUBDWORD_TILE_CAST 0 #endif +#ifndef CK_TILE_USE_PK_FP16_TILE_CAST +#define CK_TILE_USE_PK_FP16_TILE_CAST 0 +#endif + // TODO: better solve this inside compiler #ifndef CK_TILE_FMHA_FWD_FAST_EXP2 #define CK_TILE_FMHA_FWD_FAST_EXP2 0 diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 48762b7225..5fecd19dcd 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -110,7 +110,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor) namespace impl { // TODO: this is ugly template -CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) +CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors) { #if defined(__gfx94__) // This API is designed to use the _pk_ serious of function @@ -156,6 +156,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors) #endif } +template +CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors) +{ +#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) + // This API is 
designed to use the _pk_ series of function + constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); + + constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); + static_assert(thread_buffer_size % 2 == 0); + constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2; + + auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); + + // TODO: this is rtz cvt, need be very careful + for(index_t i = 0; i < thread_buffer_size_pk; i++) + { + auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0], + in_dstr_tensors.get_thread_buffer()[2 * i + 1]); + + out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x; + out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y; + } + + return out_dstr_tensor; +#else + // fallback + return tile_elementwise_in(type_convert, + in_dstr_tensors); +#endif +} + #if CK_TILE_USE_SUBDWORD_TILE_CAST // this function assume either src or dst (or both) date type is under 1 dword // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) @@ -229,8 +260,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) float> && (SrcTensor::get_thread_buffer_size() % 4 == 0)) { - return impl::cast_tile_pk_fp8x4(src_tensor); + return impl::cast_tile_pk_fp8_fp32(src_tensor); } +#if CK_TILE_USE_PK_FP16_TILE_CAST + else if constexpr(std::is_same_v && + std::is_same_v && + (SrcTensor::get_thread_buffer_size() % 2 == 0)) + { + return impl::cast_tile_pk_fp16_fp32(src_tensor); + } +#endif #if CK_TILE_USE_SUBDWORD_TILE_CAST else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4) { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp index 9939a474b2..21784fc2d2 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp +++ 
b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp @@ -578,8 +578,14 @@ struct BlockFmhaPipelineQRKSVSAsync randval_dram_window); } - const auto p = - cast_tile(tile_elementwise_in(p_compute_element_func, p_compute)); + const auto p = [&]() { + if constexpr(std::is_same_v) + return impl::cast_tile_pk_fp16_fp32( + tile_elementwise_in(p_compute_element_func, p_compute)); + else + return cast_tile( + tile_elementwise_in(p_compute_element_func, p_compute)); + }(); // STAGE 3, KV gemm if constexpr(k1_loops > 1)