From c26c60db4c5c7ab86ffc800fc0ae8e87c945c596 Mon Sep 17 00:00:00 2001 From: "PoYen, Chen" Date: Tue, 23 Jul 2024 02:59:17 +0000 Subject: [PATCH] Unify parameter/variable naming style --- .../block_fmha_fwd_appendkv_pipeline.hpp | 115 ++++++++---------- 1 file changed, 53 insertions(+), 62 deletions(-) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp index 22f2bd9f34..77a960dab6 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp @@ -76,27 +76,27 @@ struct BlockFmhaFwdAppendKVPipeline return Policy::template GetSmemSize(); } - template + typename RotaryCosDramBlockWindow, + typename RotarySinDramBlockWindow> CK_TILE_HOST_DEVICE auto - operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + operator()(const QDramBlockWindow& q_dram_block_window, // M0*K0 tile const QElementFunction& q_element_func, - KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile - const KnewDramBlockWindowTmp& knew_dram_block_window_tmp, // N0*K0 tile + KDramBlockWindow& k_dram_block_window, // N0*K0 tile + const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile const KnewElementFunction& knew_element_func, - VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile - const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp, // N1*K1 tile + VDramBlockWindow& v_dram_block_window, // N1*K1 tile + const VnewDramBlockWindow& vnew_dram_block_window, // N1*K1 tile const VnewElementFunction& vnew_element_func, - const RotaryCosBlockWindowTemp rotary_cos_block_window_tmp, - const RotarySinBlockWindowTemp rotary_sin_block_window_tmp, + const RotaryCosDramBlockWindow rotary_cos_dram_block_window, + const RotarySinDramBlockWindow rotary_sin_dram_block_window, void* smem_ptr, index_t rotary_dim = 0) const { @@ -154,34 +154,29 @@ struct BlockFmhaFwdAppendKVPipeline #endif }; - auto knew_dram_block_window = - make_tile_window(knew_dram_block_window_tmp.get_bottom_tensor_view(), - knew_dram_block_window_tmp.get_window_lengths(), - {0, 0}); - - auto knew_dram_window = + auto knew_window = make_tile_window(knew_dram_block_window.get_bottom_tensor_view(), knew_dram_block_window.get_window_lengths(), knew_dram_block_window.get_window_origin(), Policy::template MakeKnewDramTileDistribution()); auto knew_tile = [&]() { - auto knew = load_tile(knew_dram_window); + auto knew = load_tile(knew_window); return tile_elementwise_in(knew_element_func, knew); }(); if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE) { auto rotary_cos_window = - make_tile_window(rotary_cos_block_window_tmp.get_bottom_tensor_view(), - rotary_cos_block_window_tmp.get_window_lengths(), - rotary_cos_block_window_tmp.get_window_origin(), + make_tile_window(rotary_cos_dram_block_window.get_bottom_tensor_view(), + rotary_cos_dram_block_window.get_window_lengths(), + rotary_cos_dram_block_window.get_window_origin(), Policy::template MakeRotaryCosSinTileDistribution()); auto rotary_sin_window = - make_tile_window(rotary_sin_block_window_tmp.get_bottom_tensor_view(), - rotary_sin_block_window_tmp.get_window_lengths(), - rotary_sin_block_window_tmp.get_window_origin(), + make_tile_window(rotary_sin_dram_block_window.get_bottom_tensor_view(), + rotary_sin_dram_block_window.get_window_lengths(), + rotary_sin_dram_block_window.get_window_origin(), Policy::template MakeRotaryCosSinTileDistribution()); // We assume that each thread owns contiguous elements on head dimention. And we will @@ -225,11 +220,11 @@ struct BlockFmhaFwdAppendKVPipeline { bool is_left = (start_x + KPerThread) <= (rotary_dim / 2); - auto knew_other_dram_window = knew_dram_window; - move_tile_window(knew_other_dram_window, + auto knew_other_window = knew_window; + move_tile_window(knew_other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)}); - auto knew_other_tile = load_tile(knew_other_dram_window); + auto knew_other_tile = load_tile(knew_other_window); move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)}); auto rotary_cos_tile = load_tile(rotary_cos_window); @@ -253,53 +248,49 @@ struct BlockFmhaFwdAppendKVPipeline } } print_tile(knew_tile, 7); - store_tile(k_dram_block_window_tmp, knew_tile); + store_tile(k_dram_block_window, knew_tile); - auto vnew_dram_block_window = - make_tile_window(vnew_dram_block_window_tmp.get_bottom_tensor_view(), - vnew_dram_block_window_tmp.get_window_lengths(), - {0, 0}); - - auto vnew_dram_window = + auto vnew_window = make_tile_window(vnew_dram_block_window.get_bottom_tensor_view(), vnew_dram_block_window.get_window_lengths(), vnew_dram_block_window.get_window_origin(), Policy::template MakeVnewDramTileDistribution()); auto vnew_tile = [&]() { - auto vnew = load_tile(vnew_dram_window); + auto vnew = load_tile(vnew_window); return tile_elementwise_in(vnew_element_func, vnew); }(); - store_tile(v_dram_block_window_tmp, vnew_tile); + store_tile(v_dram_block_window, vnew_tile); } - template - CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, - KDramBlockWindowTmp& k_dram_block_window_tmp, - const KnewDramBlockWindowTmp& knew_dram_block_window_tmp, - VDramBlockWindowTmp& v_dram_block_window_tmp, - const VnewDramBlockWindowTmp& vnew_dram_block_window_tmp, - const RotaryCosBlockWindowTemp& rotary_cos_block_window_tmp, - const RotarySinBlockWindowTemp& rotary_sin_block_window_tmp, - void* smem_ptr, - index_t rotary_dim = 0) const + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindow& q_dram_block_window, + KDramBlockWindow& k_dram_block_window, + const KnewDramBlockWindow& knew_dram_block_window, + VDramBlockWindow& v_dram_block_window, + const VnewDramBlockWindow& vnew_dram_block_window, + const RotaryCosDramBlockWindow& rotary_cos_dram_block_window, + const RotarySinDramBlockWindow& rotary_sin_dram_block_window, + void* smem_ptr, + index_t rotary_dim = 0) const { - return operator()(q_dram_block_window_tmp, + return operator()(q_dram_block_window, identity{}, - k_dram_block_window_tmp, - knew_dram_block_window_tmp, + k_dram_block_window, + knew_dram_block_window, identity{}, - v_dram_block_window_tmp, - vnew_dram_block_window_tmp, + v_dram_block_window, + vnew_dram_block_window, identity{}, - rotary_cos_block_window_tmp, - rotary_sin_block_window_tmp, + rotary_cos_dram_block_window, + rotary_sin_dram_block_window, smem_ptr, rotary_dim); }