mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE][FA] using pk f16_f32 (#1343)
* [CK_TILE][FA] using pk f16_f32
* correct a error
[ROCm/composable_kernel commit: 17ed368f58]
This commit is contained in:
@@ -61,10 +61,13 @@ CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
asm volatile("\
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
s_barrier \
|
||||
" ::);
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
// s_barrier \
|
||||
// " ::);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#else
|
||||
__syncthreads();
|
||||
#endif
|
||||
|
||||
@@ -167,6 +167,10 @@
|
||||
#define CK_TILE_USE_SUBDWORD_TILE_CAST 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
|
||||
#define CK_TILE_USE_PK_FP16_TILE_CAST 0
|
||||
#endif
|
||||
|
||||
// TODO: better solve this inside compiler
|
||||
#ifndef CK_TILE_FMHA_FWD_FAST_EXP2
|
||||
#define CK_TILE_FMHA_FWD_FAST_EXP2 0
|
||||
|
||||
@@ -110,7 +110,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
|
||||
namespace impl {
|
||||
// TODO: this is ugly
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx94__)
|
||||
// This API is designed to use the _pk_ series of functions
|
||||
@@ -156,6 +156,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
|
||||
// This API is designed to use the _pk_ serious of function
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
|
||||
static_assert(thread_buffer_size % 2 == 0);
|
||||
constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
|
||||
|
||||
auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
|
||||
|
||||
// TODO: this is rtz cvt, need be very careful
|
||||
for(index_t i = 0; i < thread_buffer_size_pk; i++)
|
||||
{
|
||||
auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
|
||||
in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
|
||||
|
||||
out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
|
||||
out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
|
||||
}
|
||||
|
||||
return out_dstr_tensor;
|
||||
#else
|
||||
// fallback
|
||||
return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
|
||||
in_dstr_tensors);
|
||||
#endif
|
||||
}
|
||||
|
||||
#if CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
// this function assumes either src or dst (or both) data type is under 1 dword
|
||||
// we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
|
||||
@@ -229,8 +260,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
|
||||
float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 4 == 0))
|
||||
{
|
||||
return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
|
||||
return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
#if CK_TILE_USE_PK_FP16_TILE_CAST
|
||||
else if constexpr(std::is_same_v<DstType, fp16_t> &&
|
||||
std::is_same_v<typename SrcTensor::DataType, float> &&
|
||||
(SrcTensor::get_thread_buffer_size() % 2 == 0))
|
||||
{
|
||||
return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
|
||||
}
|
||||
#endif
|
||||
#if CK_TILE_USE_SUBDWORD_TILE_CAST
|
||||
else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
|
||||
{
|
||||
|
||||
@@ -578,8 +578,14 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
randval_dram_window);
|
||||
}
|
||||
|
||||
const auto p =
|
||||
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
const auto p = [&]() {
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
else
|
||||
return cast_tile<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, p_compute));
|
||||
}();
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
|
||||
Reference in New Issue
Block a user