diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp index 593ce90a31..f9560e08cb 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_pipeline.hpp @@ -118,8 +118,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename QElementFunction, - typename KElementFunction, - typename VElementFunction, typename BiasElementFunction, typename SAccElementFunction, typename PComputeElementFunction, @@ -128,10 +126,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile - const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, const SAccElementFunction& s_acc_element_func, @@ -143,9 +139,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS void* smem_ptr, DropoutType& dropout) const { - ignore = q_element_func; - ignore = k_element_func; - static_assert( std::is_same_v> && std::is_same_v{}([&](auto i_k1) { - store_tile(k_lds_write_windows[i_k1], - tile_elementwise_in(k_element_func, k_tiles[i_k1])); + store_tile(k_lds_write_windows[i_k1], k_tiles[i_k1]); __builtin_amdgcn_sched_barrier(0x00000001); @@ -464,8 +456,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffled_tile)); + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile); __builtin_amdgcn_sched_barrier(0x00000001); @@ -508,7 +499,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS shuffle_tile(v_shuffled_tile, v_tiles[number{}]); store_tile(v_lds_windows[number<(i_k1 + 3) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffled_tile)); + v_shuffled_tile); __builtin_amdgcn_sched_barrier(0x00000001); }; @@ -546,9 +537,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, - identity{}, v_dram_block_window_tmp, - identity{}, bias_dram_block_window_tmp, identity{}, identity{}, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index acb8288635..9978ffe071 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -120,8 +120,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename QElementFunction, - typename KElementFunction, - typename VElementFunction, typename BiasElementFunction, typename SAccElementFunction, typename PComputeElementFunction, @@ -130,10 +128,8 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile - const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, const SAccElementFunction& s_acc_element_func, @@ -145,9 +141,6 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad void* smem_ptr, DropoutType& dropout) const { - ignore = q_element_func; - ignore = k_element_func; - static_assert( std::is_same_v> && std::is_same_v{}([&](auto i_k1) { - store_tile(k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[i_k1])); + store_tile(k_lds_write_windows[number{}], k_tiles[i_k1]); __builtin_amdgcn_sched_barrier(0x00000001); @@ -453,7 +445,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad // STAGE 3, Gemm_1 ( O = P@V ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[number{}])); + v_tiles[number{}]); __builtin_amdgcn_sched_barrier(0x00000001); @@ -498,9 +490,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, - identity{}, v_dram_block_window_tmp, - identity{}, bias_dram_block_window_tmp, identity{}, identity{}, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 7bc8445bea..67959ebb85 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -118,8 +118,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename QElementFunction, - typename KElementFunction, - typename VElementFunction, typename BiasElementFunction, typename SAccElementFunction, typename PComputeElementFunction, @@ -128,10 +126,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile - const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, const SAccElementFunction& s_acc_element_func, @@ -143,8 +139,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS void* smem_ptr, DropoutType& dropout) const { - ignore = q_element_func; - ignore = k_element_func; ignore = scale_p; static_assert( @@ -396,9 +390,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile( - k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + store_tile(k_lds_write_windows[number{}], + k_tiles[number{}]); __builtin_amdgcn_sched_barrier(0x00000001); @@ -505,8 +498,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffled_tile)); + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_shuffled_tile); __builtin_amdgcn_sched_barrier(0x00000001); @@ -587,8 +579,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS shuffle_tile(v_shuffled_tile, v_tiles[number<1>{}]); - store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffled_tile)); + store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_shuffled_tile); __builtin_amdgcn_sched_barrier(0x00000001); @@ -616,7 +607,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS shuffle_tile(v_shuffled_tile, v_tiles[number{}]); store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffled_tile)); + v_shuffled_tile); __builtin_amdgcn_sched_barrier(0x00000001); }; @@ -668,9 +659,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, - identity{}, v_dram_block_window_tmp, - identity{}, bias_dram_block_window_tmp, identity{}, identity{}, diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 4729d5b33a..fa9f16eee7 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -120,8 +120,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename QElementFunction, - typename KElementFunction, - typename VElementFunction, typename BiasElementFunction, typename SAccElementFunction, typename PComputeElementFunction, @@ -130,10 +128,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, - const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile - const KElementFunction& k_element_func, - const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VElementFunction& v_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasElementFunction& bias_element_func, const SAccElementFunction& s_acc_element_func, @@ -145,8 +141,6 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad void* smem_ptr, DropoutType& dropout) const { - ignore = q_element_func; - ignore = k_element_func; ignore = scale_p; static_assert( @@ -373,9 +367,8 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad { // STAGE 1, Gemm_0 ( S = Q@K ) static_for<0, k1_loops, 1>{}([&](auto i_k1) { - store_tile( - k_lds_write_windows[number{}], - tile_elementwise_in(k_element_func, k_tiles[number{}])); + store_tile(k_lds_write_windows[number{}], + k_tiles[number{}]); __builtin_amdgcn_sched_barrier(0x00000001); @@ -477,8 +470,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[number<0>{}])); + store_tile(v_lds_windows[number<2 % NumKVLdsBuffers>{}], v_tiles[number<0>{}]); __builtin_amdgcn_sched_barrier(0x00000001); @@ -564,8 +556,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_s_barrier(); }; - store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[number<1>{}])); + store_tile(v_lds_windows[number<3 % NumKVLdsBuffers>{}], v_tiles[number<1>{}]); __builtin_amdgcn_sched_barrier(0x00000001); @@ -592,7 +583,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad __builtin_amdgcn_sched_barrier(0x00000001); store_tile(v_lds_windows[number<(i_k1 + 4) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tiles[number{}])); + v_tiles[number{}]); __builtin_amdgcn_sched_barrier(0x00000001); }; @@ -644,9 +635,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad return operator()(q_dram_block_window_tmp, identity{}, k_dram_block_window_tmp, - identity{}, v_dram_block_window_tmp, - identity{}, bias_dram_block_window_tmp, identity{}, identity{},