mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 11:30:02 +00:00
Remove selectable VLayout to simplify the code, since hdim is always the fastest dimension
This commit is contained in:
@@ -41,8 +41,6 @@ struct HstuAttentionFwdKernel
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::BiasDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::ODataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::VLayout>;
|
||||
|
||||
static constexpr bool kIsJagged = HstuAttentionPipeline::kIsJagged;
|
||||
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
|
||||
@@ -626,14 +624,8 @@ struct HstuAttentionFwdKernel
|
||||
|
||||
batch_offset_q = query_start * kargs.seq_stride_q;
|
||||
batch_offset_k = key_start * kargs.seq_stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.seq_stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
batch_offset_v = key_start * kargs.seq_stride_v;
|
||||
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.seq_stride_bias;
|
||||
@@ -759,41 +751,24 @@ struct HstuAttentionFwdKernel
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_v),
|
||||
make_tuple(kargs.seq_stride_v, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_v),
|
||||
make_tuple(kargs.seq_stride_v, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return pad_tensor_view(v_dram_transposed,
|
||||
make_tuple(number<HstuAttentionPipeline::kN1>{},
|
||||
number<HstuAttentionPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, false>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.hdim_v, kargs.seqlen),
|
||||
make_tuple(kargs.seq_stride_v, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(v_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kN1>{},
|
||||
number<HstuAttentionPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, false>{});
|
||||
}
|
||||
return pad_tensor_view(v_dram_transposed,
|
||||
make_tuple(number<HstuAttentionPipeline::kN1>{},
|
||||
number<HstuAttentionPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, false>{});
|
||||
}();
|
||||
|
||||
auto q_dram_window =
|
||||
|
||||
@@ -24,7 +24,6 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
using HstuMask = remove_cvref_t<typename Problem::HstuMask>;
|
||||
|
||||
using HstuAttentionTileSetting = remove_cvref_t<typename Problem::HstuAttentionTileSetting>;
|
||||
using VLayout = remove_cvref_t<typename HstuAttentionTileSetting::VLayout>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
@@ -54,12 +53,8 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
static constexpr index_t kAlignmentV =
|
||||
Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
@@ -500,27 +495,16 @@ struct HstuAttentionFwdPipelineQRKSVS
|
||||
|
||||
tile_elementwise_inout(f_silu, pcomp_tile);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tile);
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegTileDistribution<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tile);
|
||||
|
||||
// if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer
|
||||
// i+2, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unroll in next iteration
|
||||
store_tile(
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
// if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer
|
||||
// i+2, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unroll in next iteration
|
||||
store_tile(v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tile)); // store the prefetch
|
||||
};
|
||||
// if K in this unroll uses Lds-buffer i, then V in this uroll uses Lds-buffer
|
||||
// i+2, No overlap occurs between V and K in the same unroll, and V in current
|
||||
// unroll and K in next unroll or first unroll in next iteration
|
||||
store_tile(
|
||||
v_lds_windows[number<(i_k1 + 2) % NumKVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
|
||||
@@ -147,7 +147,6 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
@@ -156,24 +155,15 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
|
||||
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: (ElemPerThread / kMinVecLoad);
|
||||
constexpr index_t kVecLoad = ((ElemPerThread / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: (ElemPerThread / kMinVecLoad);
|
||||
|
||||
return kVecLoad;
|
||||
}
|
||||
else // Similar to GetAlignmentK()
|
||||
{
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(VDataType);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
return kVecLoad;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -201,38 +191,14 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVSingleSmemElementSpaceSize()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t kKPack = GetKVWarpGemmKPerThreadSize<Problem>();
|
||||
|
||||
return N0 * (N1 * kKPerBlock + kKPack);
|
||||
}
|
||||
else // similar to GetKSingleSmemElementSpaceSize()
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentV<Problem>();
|
||||
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
return kKPerBlock * kNPerBlock + kKPerBlock;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
return kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
};
|
||||
};
|
||||
return N0 * (N1 * kKPerBlock + kKPack);
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
@@ -445,202 +411,80 @@ struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
|
||||
constexpr index_t NumVLdsBuffers = GetNumKVLdsBuffers<Problem>();
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
// K2 is the vector size for storing shuffled tile to LDS
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
// K2 is the vector size for storing shuffled tile to LDS
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
|
||||
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
// GetSmemKPackV() is the vector size for loading from LDS by BlockGemm
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
|
||||
static_assert(kKPack >= K2, "Check failed!");
|
||||
static_assert(kKPack >= K2, "Check failed!");
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
|
||||
constexpr index_t VSingleSmemElementSpaceSize = N0 * (N1 * kKPerBlock + kKPack);
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
static_assert(VSingleSmemElementSpaceSize == GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
constexpr index_t SingleSmemElementSpaceSize = GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<N1 * kKPerBlock + kKPack>{},
|
||||
number<kKPerBlock>{},
|
||||
number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<N1 * kKPerBlock + kKPack>{},
|
||||
number<kKPerBlock>{},
|
||||
number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
|
||||
make_pass_through_transform(number<kKPerBlock>{})),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<N0>{}, number<N1>{})),
|
||||
make_pass_through_transform(number<kKPerBlock>{})),
|
||||
make_tuple(sequence<0, 1, 2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
else // Similar to MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentV<Problem>();
|
||||
|
||||
if constexpr(GetKVWarpGemmKPerThreadSize<Problem>() >= 8)
|
||||
{
|
||||
static_assert(kKVector == kKPack);
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize =
|
||||
kKPerBlock * kNPerBlock + kKPerBlock;
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize ==
|
||||
GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize =
|
||||
GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<NumVLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKPack + kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize =
|
||||
kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector;
|
||||
|
||||
static_assert(VSingleSmemElementSpaceSize ==
|
||||
GetVSingleSmemElementSpaceSize<Problem>());
|
||||
|
||||
constexpr index_t SingleSmemElementSpaceSize =
|
||||
GetSingleSmemElementSpaceSize<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumVLdsBuffers>{},
|
||||
number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<SingleSmemElementSpaceSize>{},
|
||||
number<kNPerBlock * kKVector + kKPack>{},
|
||||
number<kNPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(make_merge_transform(
|
||||
make_tuple(number<NumVLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
};
|
||||
}
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
// Need special consideration for RowMajor since shuffling is needed to write LDS in dwords
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
constexpr index_t K2 = ElemPerThread / N1;
|
||||
constexpr index_t K1 = get_warp_size() / N0;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
else // Similar to MakeKDramTileDistribution()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<2, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledVRegTileDistribution()
|
||||
{
|
||||
// This tile-distribuiton only used when V layout is RowMajor
|
||||
using VLayout = remove_cvref_t<typename Problem::HstuAttentionTileSetting::VLayout>;
|
||||
static_assert(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::HstuAttentionTileSetting::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::HstuAttentionTileSetting::kK1;
|
||||
|
||||
@@ -63,8 +63,7 @@ struct HstuAttentionFwdTileSetting<32>
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -75,8 +74,7 @@ struct HstuAttentionFwdTileSetting<64>
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -87,8 +85,7 @@ struct HstuAttentionFwdTileSetting<128>
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -99,8 +96,7 @@ struct HstuAttentionFwdTileSetting<256>
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -147,8 +143,7 @@ struct HstuAttentionFwdTileSetting<32>
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -159,8 +154,7 @@ struct HstuAttentionFwdTileSetting<64>
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -171,8 +165,7 @@ struct HstuAttentionFwdTileSetting<128>
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
@@ -183,7 +176,6 @@ struct HstuAttentionFwdTileSetting<256>
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile2,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
HstuAttentionFwdWarpTile1>;
|
||||
};
|
||||
#endif
|
||||
|
||||
@@ -30,5 +30,3 @@ struct HstuAttentionFwdTypeConfig<ck_tile::bf16_t>
|
||||
using OaccDataType = GemmAccDataType;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
static constexpr bool IsVLayoutRowMajor = true;
|
||||
|
||||
@@ -27,8 +27,7 @@ template <typename BlockTile_, // sequence<...
|
||||
typename Gemm0BlockWarps_,
|
||||
typename Gemm0WarpTile_,
|
||||
typename Gemm1BlockWarps_,
|
||||
typename Gemm1WarpTile_,
|
||||
bool IsVLayoutRowMajor_>
|
||||
typename Gemm1WarpTile_>
|
||||
struct HstuAttentionFwdTileSettingClass
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
@@ -56,12 +55,6 @@ struct HstuAttentionFwdTileSettingClass
|
||||
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
|
||||
|
||||
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
|
||||
|
||||
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
|
||||
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
|
||||
using VLayout = std::conditional_t<IsVLayoutRowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user