Remove constness from q_ptr

This commit is contained in:
PoYen, Chen
2024-07-23 03:11:31 +00:00
parent c26c60db4c
commit 1dbed18555
4 changed files with 65 additions and 9 deletions

View File

@@ -79,7 +79,7 @@ struct FmhaFwdAppendKVKernel
// Users need to use the MakeKargs() function to create kargs.
struct CommonKargs
{
const void* q_ptr;
void* q_ptr;
void* k_ptr;
const void* knew_ptr;
void* v_ptr;
@@ -139,7 +139,7 @@ struct FmhaFwdAppendKVKernel
template <bool Cond = !kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
@@ -211,7 +211,7 @@ struct FmhaFwdAppendKVKernel
template <bool Cond = kIsGroupMode>
__host__ static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* q_ptr,
MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
@@ -384,9 +384,9 @@ struct FmhaFwdAppendKVKernel
}
// for simplicity, the batch offset is applied by adjusting the pointer directly
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
KDataType* k_ptr =
reinterpret_cast<KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +

View File

@@ -87,7 +87,7 @@ struct BlockFmhaFwdAppendKVPipeline
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindow& q_dram_block_window, // M0*K0 tile
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
const QElementFunction& q_element_func,
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
@@ -261,6 +261,28 @@ struct BlockFmhaFwdAppendKVPipeline
return tile_elementwise_in(vnew_element_func, vnew);
}();
store_tile(v_dram_block_window, vnew_tile);
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
{
auto q_window = make_tile_window(q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>());
auto q_tile = [&]() {
auto q = load_tile(q_window);
return tile_elementwise_in(q_element_func, q);
}();
// We assume that each thread owns contiguous elements along the head dimension, and we will
// use the distribution to enable/disable threads in order to override knew_tile content.
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {}
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
{
}
store_tile(q_dram_block_window, q_tile);
}
}
template <typename QDramBlockWindow,
@@ -271,7 +293,7 @@ struct BlockFmhaFwdAppendKVPipeline
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindow& q_dram_block_window,
operator()(QDramBlockWindow& q_dram_block_window,
KDramBlockWindow& k_dram_block_window,
const KnewDramBlockWindow& knew_dram_block_window,
VDramBlockWindow& v_dram_block_window,

View File

@@ -57,6 +57,40 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD);
}
// Builds the static tile distribution that the pipeline uses for its Q tile window.
//
// All sizes are compile-time constants taken from `Problem`:
//   - Problem::kBlockSize  : number of threads in the block
//   - Problem::kTileSizeS  : tile extent along the first (M) dimension
//   - Problem::kTileSizeD  : tile extent along the second (K) dimension
//   - Problem::RotaryEnum  : selects how many bytes each thread covers along K
//   - Problem::QDataType   : element type, used to turn that byte count into elements
//
// Returns the object produced by make_static_tile_distribution() for the encoding below.
//
// NOTE(review): every quotient below is a truncating integer division; this presumably
// relies on the tile sizes, block size and warp size dividing evenly - confirm against
// the supported Problem configurations.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
    using QElem = remove_cvref_t<typename Problem::QDataType>;

    constexpr index_t kThreadsPerBlock = Problem::kBlockSize;
    constexpr index_t kTileRows        = Problem::kTileSizeS;
    constexpr index_t kTileCols        = Problem::kTileSizeD;

    // Each thread covers 8 bytes of K-contiguous elements in HALF_ROTATED mode and
    // 16 bytes in every other mode.
    constexpr bool kHalfRotated =
        Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED;
    constexpr index_t kColsPerThread = (kHalfRotated ? 8 : 16) / sizeof(QElem);

    // Split the K extent across threads, then fill the rest of each warp along M.
    constexpr auto kWarpSize             = get_warp_size();
    constexpr index_t kThreadsAlongCols  = kTileCols / kColsPerThread;
    constexpr index_t kRowThreadsPerWarp = kWarpSize / kThreadsAlongCols;
    constexpr index_t kWarpCount         = kThreadsPerBlock / kWarpSize;
    constexpr index_t kRowsPerThread     = kTileRows / (kWarpCount * kRowThreadsPerWarp);

    // M is factored as (rows per thread, warps, row-threads per warp);
    // K is factored as (threads along K, elements per thread).
    using RowFactors = sequence<kRowsPerThread, kWarpCount, kRowThreadsPerWarp>;
    using ColFactors = sequence<kThreadsAlongCols, kColsPerThread>;

    using Encoding = tile_distribution_encoding<sequence<1>,
                                                tuple<RowFactors, ColFactors>,
                                                tuple<sequence<1>, sequence<1, 2>>,
                                                tuple<sequence<1>, sequence<2, 0>>,
                                                sequence<1, 2>,
                                                sequence<0, 1>>;

    return make_static_tile_distribution(Encoding{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
{