From e40ab20b9e229ebd25c9a36f2e493fdf0d3296d5 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Mon, 3 Nov 2025 08:39:43 +0000 Subject: [PATCH] Clarifying the using of CK_TILE_HOST and CK_TILE_HOST_DEVICE trying to save compiling time --- .../hstu_attention_epilogue.hpp | 4 +- .../hstu_attention_fwd_kernel.hpp | 2 +- ...hstu_attention_no_softmax_fwd_pipeline.hpp | 6 +-- ...tention_no_softmax_fwd_trload_pipeline.hpp | 6 +-- ...tu_attention_with_softmax_fwd_pipeline.hpp | 6 +-- ...ntion_with_softmax_fwd_trload_pipeline.hpp | 6 +-- .../18_hstu_attention/hstu_block_masking.hpp | 48 ++----------------- 7 files changed, 18 insertions(+), 60 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp index 76daebc455..d287c3ec33 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_epilogue.hpp @@ -18,7 +18,7 @@ struct NRepetitions2DEpilogue static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; - CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } + CK_TILE_DEVICE static constexpr index_t GetSmemSize() { return 0; } template ; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); } @@ -121,7 +121,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS typename PComputeElementFunction, typename OAccElementFunction, typename HstuMask> - CK_TILE_HOST_DEVICE auto + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile @@ -516,7 +516,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVS typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename HstuMask> - CK_TILE_HOST_DEVICE auto + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp index 7d4a1147f7..94fa7b1494 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_no_softmax_fwd_trload_pipeline.hpp @@ -108,7 +108,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad using DropoutType = std::conditional_t; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); } @@ -125,7 +125,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad typename PComputeElementFunction, typename OAccElementFunction, typename HstuMask> - CK_TILE_HOST_DEVICE auto + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile @@ -507,7 +507,7 @@ struct HstuAttentionNoSoftmaxFwdPipelineQRKSVSTrLoad typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename HstuMask> - CK_TILE_HOST_DEVICE auto + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp index 4adb7bf00c..5cd5ec58e9 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_pipeline.hpp @@ -108,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS using DropoutType = std::conditional_t; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); } @@ -125,7 +125,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename PComputeElementFunction, typename OAccElementFunction, typename HstuMask> - CK_TILE_HOST_DEVICE auto + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile @@ -612,7 +612,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVS typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename HstuMask> - CK_TILE_HOST_DEVICE auto + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp index 7b0262c598..04b9584875 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_with_softmax_fwd_trload_pipeline.hpp @@ -108,7 +108,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad using DropoutType = std::conditional_t; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + CK_TILE_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return Policy::template GetSmemSize(); } @@ -125,7 +125,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename PComputeElementFunction, typename OAccElementFunction, typename HstuMask> - CK_TILE_HOST_DEVICE auto + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const QElementFunction& q_element_func, const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile @@ -599,7 +599,7 @@ struct HstuAttentionWithSoftmaxFwdPipelineQRKSVSTrLoad typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename HstuMask> - CK_TILE_HOST_DEVICE auto + CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*KSubQKHeaddim tile const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp index 22b0fdbfe6..f8e53ebd05 100644 --- a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -69,7 +69,7 @@ struct HstuBlockMaskWithLocal // use this if need loop over X axis tile by tile (eg. seqlen_k loop-over) // i_y is the start offset of the current tile along the seqlen_q dimension template - CK_TILE_HOST_DEVICE constexpr auto + CK_TILE_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { // handle two special cases first @@ -339,7 +339,7 @@ struct HstuBlockMaskNoLocal // use this if need loop over X axis tile by tile (eg. seqlen_k loop-over) // i_y is the start offset of the current tile along the seqlen_q dimension template - CK_TILE_HOST_DEVICE constexpr auto + CK_TILE_DEVICE constexpr auto GetTileRangeAlongX(index_t i_y, number, number) const { if constexpr(!IsMasking) @@ -368,7 +368,7 @@ struct HstuBlockMaskNoLocal }; } - CK_TILE_HOST bool IsTokenPairInsideMask(int row, int col) + CK_TILE_HOST_DEVICE bool IsTokenPairInsideMask(int row, int col) { int row_id; int col_id; @@ -406,48 +406,6 @@ struct HstuBlockMaskNoLocal }; }; - CK_TILE_DEVICE bool IsTokenPairInsideMask(int row, int col) - { - int row_id; - int col_id; - - if(contextual_seqlen > 0) - { - // row_id/col_id is clamped from physical row/col according to contextual_seqlen and - // max_uih_len - row_id = max(row - contextual_seqlen + 1, 0); - col_id = max(col - contextual_seqlen + 1, 0); - - row_id = min(row_id, max_row_id); - col_id = min(col_id, max_col_id); - - if(row_id == 0 && col_id < max_col_id) - return true; - } - else - { - // row_id/col_id is clamped from physical row/col according to contextual_seqlen and - // max_uih_len - row_id = min(row, max_row_id); - col_id = min(col, max_col_id); - }; - - // use row_id/col_id to check the dist between two q/k token pair, token pairs on the - // diagonal line are always considerred - if constexpr(IsMasking) - { - bool res = ((row_id > col_id) || (row == col)); - - return res; - } - else - { - bool res = ((row_id != col_id) || (row == col)); - - return res; - }; - }; - // if the whole tile inside the masking area, no need for pixel-by-pixel checking template CK_TILE_DEVICE constexpr bool IsFullTileInsideMask(index_t i_tile_top,