[CK_TILE][FA] using pk f16_f32 (#1343)

* [CK_TILE][FA] using pk f16_f32

* correct an error

[ROCm/composable_kernel commit: 17ed368f58]
This commit is contained in:
carlushuang
2024-06-17 17:16:46 +08:00
committed by GitHub
parent 4847f3beb4
commit 447beaec1e
4 changed files with 60 additions and 8 deletions

View File

@@ -578,8 +578,14 @@ struct BlockFmhaPipelineQRKSVSAsync
randval_dram_window);
}
-const auto p =
-cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
+const auto p = [&]() {
+if constexpr(std::is_same_v<PDataType, fp16_t>)
+return impl::cast_tile_pk_fp16_fp32<PDataType>(
+tile_elementwise_in(p_compute_element_func, p_compute));
+else
+return cast_tile<PDataType>(
+tile_elementwise_in(p_compute_element_func, p_compute));
+}();
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)