From 19fc2a9051419d9dcb5c10746ddda43120911888 Mon Sep 17 00:00:00 2001 From: Qianfeng Zhang Date: Wed, 20 Aug 2025 08:35:51 +0000 Subject: [PATCH] Remove selectable VLayout for simplifying the codes since hdim is always fatest dimension --- .../hstu_attention_fwd_kernel.hpp | 61 ++-- .../hstu_attention_fwd_pipeline.hpp | 38 +-- ..._attention_fwd_pipeline_default_policy.hpp | 260 ++++-------------- .../hstu_attention_fwd_setting.hpp | 24 +- .../hstu_attention_fwd_type_config.hpp | 2 - .../hstu_attention_tile_setting_define.hpp | 9 +- 6 files changed, 90 insertions(+), 304 deletions(-) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp index 788cb719b1..382aa75979 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -41,8 +41,6 @@ struct HstuAttentionFwdKernel using BiasDataType = ck_tile::remove_cvref_t; using ODataType = ck_tile::remove_cvref_t; - using VLayout = ck_tile::remove_cvref_t; - static constexpr bool kIsJagged = HstuAttentionPipeline::kIsJagged; static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ; static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK; @@ -626,14 +624,8 @@ struct HstuAttentionFwdKernel batch_offset_q = query_start * kargs.seq_stride_q; batch_offset_k = key_start * kargs.seq_stride_k; - if constexpr(std::is_same_v) - { - batch_offset_v = key_start * kargs.seq_stride_v; - } - else - { - batch_offset_v = key_start; - } + batch_offset_v = key_start * kargs.seq_stride_v; + if constexpr(kHasBias) { batch_offset_bias = query_start * kargs.seq_stride_bias; @@ -759,41 +751,24 @@ struct HstuAttentionFwdKernel sequence{}); }(); const auto v_dram = [&]() { - if constexpr(std::is_same_v) - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.seqlen, kargs.hdim_v), - make_tuple(kargs.seq_stride_v, 1), - number{}, - number<1>{}); + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen, kargs.hdim_v), + make_tuple(kargs.seq_stride_v, 1), + number{}, + number<1>{}); - const auto v_dram_transposed = - transform_tensor_view(v_dram_naive, - make_tuple(make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.seqlen)), - make_tuple(sequence<1>{}, sequence<0>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return pad_tensor_view(v_dram_transposed, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - const auto v_dram_naive = make_naive_tensor_view( - v_ptr, - make_tuple(kargs.hdim_v, kargs.seqlen), - make_tuple(kargs.seq_stride_v, 1), - number{}, - number<1>{}); - - return pad_tensor_view(v_dram_naive, - make_tuple(number{}, - number{}), - sequence{}); - } + return pad_tensor_view(v_dram_transposed, + make_tuple(number{}, + number{}), + sequence{}); }(); auto q_dram_window = diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp index 1a01151972..d812c2d0ee 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -24,7 +24,6 @@ struct HstuAttentionFwdPipelineQRKSVS using HstuMask = remove_cvref_t; using HstuAttentionTileSetting = remove_cvref_t; - using VLayout = remove_cvref_t; static constexpr index_t kBlockSize = Problem::kBlockSize; @@ -54,12 +53,8 @@ struct HstuAttentionFwdPipelineQRKSVS kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ(); static constexpr index_t kAlignmentK = kPadHeadDimQK ? 1 : Policy::template GetAlignmentK(); - static constexpr index_t kAlignmentV = []() { - if constexpr(std::is_same_v) - return Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); - else - return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); - }(); + static constexpr index_t kAlignmentV = + Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); static constexpr index_t kAlignmentO = kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); @@ -500,27 +495,16 @@ struct HstuAttentionFwdPipelineQRKSVS tile_elementwise_inout(f_silu, pcomp_tile); - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegTileDistribution()); - shuffle_tile(v_shuffle_tmp, v_tile); + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegTileDistribution()); + shuffle_tile(v_shuffle_tmp, v_tile); - // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer - // i+2, No overlap occurs between V and K in the same unroll, and V in current - // unroll and K in next unroll or first unroll in next iteration - store_tile( - v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer - // i+2, No overlap occurs between V and K in the same unroll, and V in current - // unroll and K in next unroll or first unroll in next iteration - store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], - tile_elementwise_in(v_element_func, v_tile)); // store the prefetch - }; + // if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer + // i+2, No overlap occurs between V and K in the same unroll, and V in current + // unroll and K in next unroll or first unroll in next iteration + store_tile( + v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch if constexpr(kHasDropout) { diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp index 6a01e5daba..d26913deaa 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -147,7 +147,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() { - using VLayout = remove_cvref_t; using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; @@ -156,24 +155,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords - if constexpr(std::is_same_v) - { - constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); + constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + constexpr index_t kMinVecLoad = 4 / sizeof(VDataType); - constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) - ? kMaxVecLoad - : (ElemPerThread / kMinVecLoad); + constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad) + ? kMaxVecLoad + : (ElemPerThread / kMinVecLoad); - return kVecLoad; - } - else // Similar to GetAlignmentK() - { - constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); - return min(ElemPerThread, MaxVectorSize); - } + return kVecLoad; } template @@ -201,38 +191,14 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize() { - using VLayout = remove_cvref_t; - constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords - if constexpr(std::is_same_v) - { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize(); - return N0 * (N1 * kKPerBlock + kKPack); - } - else // similar to GetKSingleSmemElementSpaceSize() - { - constexpr index_t kKPack = GetSmemKPackV(); - constexpr index_t kKVector = GetAlignmentV(); - - if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) - { - static_assert(kKVector == kKPack); - - return kKPerBlock * kNPerBlock + kKPerBlock; - } - else - { - static_assert(kKVector % kKPack == 0); - - return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; - }; - }; + return N0 * (N1 * kKPerBlock + kKPack); }; template @@ -445,202 +411,80 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() { - using VLayout = remove_cvref_t; - constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers(); constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords - if constexpr(std::is_same_v) - { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - // K2 is the vector size for storing shuffled tile to LDS - constexpr index_t K2 = ElemPerThread / N1; + // K2 is the vector size for storing shuffled tile to LDS + constexpr index_t K2 = ElemPerThread / N1; - // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm - constexpr index_t kKPack = GetSmemKPackV(); + // GetSmemKPackV() is the vector size for loading from LDS by BlockGemm + constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack >= K2, "Check failed!"); + static_assert(kKPack >= K2, "Check failed!"); - constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); + constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack); - static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); + static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize()); - constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); + constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize(); - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple( - number{}, number{}, number{}, number{}), - make_tuple(number{}, - number{}, - number{}, - number<1>{}), - number<8>{}, - number<1>{}); + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}, number{}), + make_tuple(number{}, + number{}, + number{}, + number<1>{}), + number<8>{}, + number<1>{}); - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_merge_transform( - make_tuple(number{}, number{}, number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple(make_merge_transform( + make_tuple(number{}, number{}, number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<0, 1, 2>{}, sequence<3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return v_lds_block_desc; - } - else // Similar to MakeKLdsBlockDescriptor() - { - constexpr index_t kKPack = GetSmemKPackV(); - constexpr index_t kKVector = GetAlignmentV(); - - if constexpr(GetKVWarpGemmKPerThreadSize() >= 8) - { - static_assert(kKVector == kKPack); - - constexpr index_t VSingleSmemElementSpaceSize = - kKPerBlock * kNPerBlock + kKPerBlock; - - static_assert(VSingleSmemElementSpaceSize == - GetVSingleSmemElementSpaceSize()); - - constexpr index_t SingleSmemElementSpaceSize = - GetSingleSmemElementSpaceSize(); - - constexpr auto v_lds_block_desc_0 = - make_naive_tensor_descriptor(make_tuple(number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 2>{}, sequence<1, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return v_lds_block_desc; - } - else - { - static_assert(kKVector % kKPack == 0); - - constexpr index_t VSingleSmemElementSpaceSize = - kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector; - - static_assert(VSingleSmemElementSpaceSize == - GetVSingleSmemElementSpaceSize()); - - constexpr index_t SingleSmemElementSpaceSize = - GetSingleSmemElementSpaceSize(); - - constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}), - make_tuple(number{}, - number{}, - number{}, - number{}, - number<1>{}), - number{}, - number<1>{}); - - constexpr auto v_lds_block_desc = transform_tensor_descriptor( - v_lds_block_desc_0, - make_tuple(make_merge_transform( - make_tuple(number{}, number{})), - make_merge_transform(make_tuple(number{}, - number{}, - number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return v_lds_block_desc; - }; - } + return v_lds_block_desc; } template CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() { - using VLayout = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; - // Need special consideration for RowMajor since shuffling is needed to write LDS in dwords - if constexpr(std::is_same_v) - { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; - constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(ElemPerThread % N1 == 0); + static_assert(ElemPerThread % N1 == 0); - constexpr index_t K2 = ElemPerThread / N1; - constexpr index_t K1 = get_warp_size() / N0; - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t K2 = ElemPerThread / N1; + constexpr index_t K1 = get_warp_size() / N0; + constexpr index_t K0 = kBlockSize / get_warp_size(); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<2, 1>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<2, 1>>{}); - } - else // Similar to MakeKDramTileDistribution() - { - using QKVDataType = remove_cvref_t; - - constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; - - constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - - constexpr index_t KPerThread = kMaxVecLoad; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); - } + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, + tuple, sequence<1, 0>>, + sequence<2, 1>, + sequence<2, 1>>{}); } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution() { - // This tile-distribuiton only used when V layout is RowMajor - using VLayout = remove_cvref_t; - static_assert(std::is_same_v); - constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1; constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp index 772aba2bd0..3cbb31fd75 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -63,8 +63,7 @@ struct HstuAttentionFwdTileSetting<32> typename HstuAttentionFwdBlockTile<32>::gemm0_warps, HstuAttentionFwdWarpTile1, typename HstuAttentionFwdBlockTile<32>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + HstuAttentionFwdWarpTile1>; }; template <> @@ -75,8 +74,7 @@ struct HstuAttentionFwdTileSetting<64> typename HstuAttentionFwdBlockTile<64>::gemm0_warps, HstuAttentionFwdWarpTile1, typename HstuAttentionFwdBlockTile<64>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + HstuAttentionFwdWarpTile1>; }; template <> @@ -87,8 +85,7 @@ struct HstuAttentionFwdTileSetting<128> typename HstuAttentionFwdBlockTile<128>::gemm0_warps, HstuAttentionFwdWarpTile1, typename HstuAttentionFwdBlockTile<128>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + HstuAttentionFwdWarpTile1>; }; template <> @@ -99,8 +96,7 @@ struct HstuAttentionFwdTileSetting<256> typename HstuAttentionFwdBlockTile<256>::gemm0_warps, HstuAttentionFwdWarpTile1, typename HstuAttentionFwdBlockTile<256>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + HstuAttentionFwdWarpTile1>; }; #endif @@ -147,8 +143,7 @@ struct HstuAttentionFwdTileSetting<32> typename HstuAttentionFwdBlockTile<32>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionFwdBlockTile<32>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + HstuAttentionFwdWarpTile1>; }; template <> @@ -159,8 +154,7 @@ struct HstuAttentionFwdTileSetting<64> typename HstuAttentionFwdBlockTile<64>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionFwdBlockTile<64>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + HstuAttentionFwdWarpTile1>; }; template <> @@ -171,8 +165,7 @@ struct HstuAttentionFwdTileSetting<128> typename HstuAttentionFwdBlockTile<128>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionFwdBlockTile<128>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + HstuAttentionFwdWarpTile1>; }; template <> @@ -183,7 +176,6 @@ struct HstuAttentionFwdTileSetting<256> typename HstuAttentionFwdBlockTile<256>::gemm0_warps, HstuAttentionFwdWarpTile2, typename HstuAttentionFwdBlockTile<256>::gemm1_warps, - HstuAttentionFwdWarpTile1, - IsVLayoutRowMajor>; + HstuAttentionFwdWarpTile1>; }; #endif diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_type_config.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_type_config.hpp index 4afc9e14d2..83d2682eb6 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_type_config.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_type_config.hpp @@ -30,5 +30,3 @@ struct HstuAttentionFwdTypeConfig using OaccDataType = GemmAccDataType; using ODataType = ck_tile::bf16_t; }; - -static constexpr bool IsVLayoutRowMajor = true; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp index a44c33fddd..48aec57f17 100644 --- a/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp +++ b/example/ck_tile/18_hstu_attention/hstu_attention_tile_setting_define.hpp @@ -27,8 +27,7 @@ template + typename Gemm1WarpTile_> struct HstuAttentionFwdTileSettingClass { using BlockTile = remove_cvref_t; @@ -56,12 +55,6 @@ struct HstuAttentionFwdTileSettingClass static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0"); static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim); - - // v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen - static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_; - using VLayout = std::conditional_t; }; } // namespace ck_tile