mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 03:19:48 +00:00
Use packed cast_tile for fp16
This commit is contained in:
@@ -249,7 +249,7 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
// elementwise SiLU activation: computes x / (1 + exp(-x)) and writes the result back into x
|
||||
const auto f_silu = [](CompDataType& x) {
|
||||
auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
const auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
return x = x / (one + exp(-x));
|
||||
};
|
||||
@@ -442,7 +442,13 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
const auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, s));
|
||||
const auto p = [&]() {
|
||||
if constexpr(std::is_same_v<PDataType, fp16_t>)
|
||||
return impl::cast_tile_pk_fp16_fp32<PDataType>(
|
||||
tile_elementwise_in(p_compute_element_func, s));
|
||||
else
|
||||
return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, s));
|
||||
}();
|
||||
|
||||
move_tile_window(k_dram_window, {kN0, -(k0_loops - 1) * kK0});
|
||||
k_tile = load_tile(k_dram_window);
|
||||
|
||||
Reference in New Issue
Block a user