Remove selectable VLayout to simplify the code, since hdim is always the fastest dimension

This commit is contained in:
Qianfeng Zhang
2025-08-20 08:35:51 +00:00
parent 15e6be5c79
commit 19fc2a9051
6 changed files with 90 additions and 304 deletions

View File

@@ -41,8 +41,6 @@ struct HstuAttentionFwdKernel
using BiasDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::BiasDataType>;
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::ODataType>;
using VLayout = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::VLayout>;
static constexpr bool kIsJagged = HstuAttentionPipeline::kIsJagged;
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
@@ -626,14 +624,8 @@ struct HstuAttentionFwdKernel
batch_offset_q = query_start * kargs.seq_stride_q;
batch_offset_k = key_start * kargs.seq_stride_k;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
batch_offset_v = key_start * kargs.seq_stride_v;
}
else
{
batch_offset_v = key_start;
}
batch_offset_v = key_start * kargs.seq_stride_v;
if constexpr(kHasBias)
{
batch_offset_bias = query_start * kargs.seq_stride_bias;
@@ -759,41 +751,24 @@ struct HstuAttentionFwdKernel
sequence<false, kPadHeadDimQK>{});
}();
const auto v_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen, kargs.hdim_v),
make_tuple(kargs.seq_stride_v, 1),
number<HstuAttentionPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.seqlen, kargs.hdim_v),
make_tuple(kargs.seq_stride_v, 1),
number<HstuAttentionPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return pad_tensor_view(v_dram_transposed,
make_tuple(number<HstuAttentionPipeline::kN1>{},
number<HstuAttentionPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{});
}
else
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen),
make_tuple(kargs.seq_stride_v, 1),
number<HstuAttentionPipeline::kAlignmentV>{},
number<1>{});
return pad_tensor_view(v_dram_naive,
make_tuple(number<HstuAttentionPipeline::kN1>{},
number<HstuAttentionPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{});
}
return pad_tensor_view(v_dram_transposed,
make_tuple(number<HstuAttentionPipeline::kN1>{},
number<HstuAttentionPipeline::kK1>{}),
sequence<kPadHeadDimV, false>{});
}();
auto q_dram_window =

View File

@@ -24,7 +24,6 @@ struct HstuAttentionFwdPipelineQRKSVS
using HstuMask = remove_cvref_t<typename Problem::HstuMask>;
using HstuAttentionTileSetting = remove_cvref_t<typename Problem::HstuAttentionTileSetting>;
using VLayout = remove_cvref_t<typename HstuAttentionTileSetting::VLayout>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -54,12 +53,8 @@ struct HstuAttentionFwdPipelineQRKSVS
kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentV =
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
@@ -500,27 +495,16 @@ struct HstuAttentionFwdPipelineQRKSVS
tile_elementwise_inout(f_silu, pcomp_tile);
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>());
shuffle_tile(v_shuffle_tmp, v_tile);
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
Policy::template MakeShuffledVRegTileDistribution<Problem>());
shuffle_tile(v_shuffle_tmp, v_tile);
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
// i+2. No overlap occurs between V and K in the same unroll, nor between V in the
// current unroll and K in the next unroll or the first unroll of the next iteration
store_tile(
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
// i+2. No overlap occurs between V and K in the same unroll, nor between V in the
// current unroll and K in the next unroll or the first unroll of the next iteration
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_tile)); // store the prefetch
};
// if K in this unroll uses Lds-buffer i, then V in this unroll uses Lds-buffer
// i+2. No overlap occurs between V and K in the same unroll, nor between V in the
// current unroll and K in the next unroll or the first unroll of the next iteration
store_tile(
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
if constexpr(kHasDropout)
{

View File

@@ -147,7 +147,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
@@ -156,24 +155,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (ElemPerThread / kMinVecLoad);
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (ElemPerThread / kMinVecLoad);
return kVecLoad;
}
else // Similar to GetAlignmentK()
{
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
return min(ElemPerThread, MaxVectorSize);
}
return kVecLoad;
}
template <typename Problem>
@@ -201,38 +191,14 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
{
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
return N0 * (N1 * kKPerBlock + kKPack);
}
else // similar to GetKSingleSmemElementSpaceSize()
{
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t kKVector = GetAlignmentV<Problem>();
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
return kKPerBlock * kNPerBlock + kKPerBlock;
}
else
{
static_assert(kKVector % kKPack == 0);
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
};
};
return N0 * (N1 * kKPerBlock + kKPack);
};
template <typename Problem>
@@ -445,202 +411,80 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
{
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
// K2 is the vector size for storing shuffled tile to LDS
constexpr index_t K2 = ElemPerThread / N1;
// K2 is the vector size for storing shuffled tile to LDS
constexpr index_t K2 = ElemPerThread / N1;
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
constexpr index_t kKPack = GetSmemKPackV<Problem>();
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
constexpr index_t kKPack = GetSmemKPackV<Problem>();
static_assert(kKPack >= K2, "Check failed!");
static_assert(kKPack >= K2, "Check failed!");
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(
number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<N1 * kKPerBlock + kKPack>{},
number<kKPerBlock>{},
number<1>{}),
number<8>{},
number<1>{});
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<N1 * kKPerBlock + kKPack>{},
number<kKPerBlock>{},
number<1>{}),
number<8>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
else // Similar to MakeKLdsBlockDescriptor()
{
constexpr index_t kKPack = GetSmemKPackV<Problem>();
constexpr index_t kKVector = GetAlignmentV<Problem>();
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
{
static_assert(kKVector == kKPack);
constexpr index_t VSingleSmemElementSpaceSize =
kKPerBlock * kNPerBlock + kKPerBlock;
static_assert(VSingleSmemElementSpaceSize ==
GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize =
GetSingleSmemElementSpaceSize<Problem>();
constexpr auto v_lds_block_desc_0 =
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKPack + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
}
else
{
static_assert(kKVector % kKPack == 0);
constexpr index_t VSingleSmemElementSpaceSize =
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
static_assert(VSingleSmemElementSpaceSize ==
GetVSingleSmemElementSpaceSize<Problem>());
constexpr index_t SingleSmemElementSpaceSize =
GetSingleSmemElementSpaceSize<Problem>();
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<NumVLdsBuffers>{},
number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kNPerBlock>{},
number<kKPack>{}),
make_tuple(number<SingleSmemElementSpaceSize>{},
number<kNPerBlock * kKVector + kKPack>{},
number<kNPerBlock * kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
v_lds_block_desc_0,
make_tuple(make_merge_transform(
make_tuple(number<NumVLdsBuffers>{}, number<kNPerBlock>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
number<kKVector / kKPack>{},
number<kKPack>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return v_lds_block_desc;
};
}
return v_lds_block_desc;
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t N1 = GetAlignmentV<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(ElemPerThread % N1 == 0);
static_assert(ElemPerThread % N1 == 0);
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t K2 = ElemPerThread / N1;
constexpr index_t K1 = get_warp_size() / N0;
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<2, 1>>{});
}
else // Similar to MakeKDramTileDistribution()
{
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
constexpr index_t KPerThread = kMaxVecLoad;
constexpr index_t KThreads = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
sequence<KThreads, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<2>, sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
tuple<sequence<2>, sequence<2, 1>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<2, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution()
{
// This tile-distribution is only used when the V layout is RowMajor
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;

View File

@@ -63,8 +63,7 @@ struct HstuAttentionFwdTileSetting<32>
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
HstuAttentionFwdWarpTile1>;
};
template <>
@@ -75,8 +74,7 @@ struct HstuAttentionFwdTileSetting<64>
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
HstuAttentionFwdWarpTile1>;
};
template <>
@@ -87,8 +85,7 @@ struct HstuAttentionFwdTileSetting<128>
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
HstuAttentionFwdWarpTile1>;
};
template <>
@@ -99,8 +96,7 @@ struct HstuAttentionFwdTileSetting<256>
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile1,
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
HstuAttentionFwdWarpTile1>;
};
#endif
@@ -147,8 +143,7 @@ struct HstuAttentionFwdTileSetting<32>
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
HstuAttentionFwdWarpTile1>;
};
template <>
@@ -159,8 +154,7 @@ struct HstuAttentionFwdTileSetting<64>
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
HstuAttentionFwdWarpTile1>;
};
template <>
@@ -171,8 +165,7 @@ struct HstuAttentionFwdTileSetting<128>
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
HstuAttentionFwdWarpTile1>;
};
template <>
@@ -183,7 +176,6 @@ struct HstuAttentionFwdTileSetting<256>
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
HstuAttentionFwdWarpTile2,
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
HstuAttentionFwdWarpTile1,
IsVLayoutRowMajor>;
HstuAttentionFwdWarpTile1>;
};
#endif

View File

@@ -30,5 +30,3 @@ struct HstuAttentionFwdTypeConfig<ck_tile::bf16_t>
using OaccDataType = GemmAccDataType;
using ODataType = ck_tile::bf16_t;
};
static constexpr bool IsVLayoutRowMajor = true;

View File

@@ -27,8 +27,7 @@ template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
typename Gemm1BlockWarps_,
typename Gemm1WarpTile_,
bool IsVLayoutRowMajor_>
typename Gemm1WarpTile_>
struct HstuAttentionFwdTileSettingClass
{
using BlockTile = remove_cvref_t<BlockTile_>;
@@ -56,12 +55,6 @@ struct HstuAttentionFwdTileSettingClass
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
using VLayout = std::conditional_t<IsVLayoutRowMajor,
ck_tile::tensor_layout::gemm::RowMajor,
ck_tile::tensor_layout::gemm::ColumnMajor>;
};
} // namespace ck_tile