mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 07:51:52 +00:00
Remove constness from q_ptr
This commit is contained in:
@@ -156,7 +156,7 @@ struct fmha_fwd_args
|
||||
|
||||
struct fmha_fwd_appendkv_args
|
||||
{
|
||||
const void* q_ptr;
|
||||
void* q_ptr;
|
||||
void* k_ptr;
|
||||
const void* knew_ptr;
|
||||
void* v_ptr;
|
||||
|
||||
@@ -79,7 +79,7 @@ struct FmhaFwdAppendKVKernel
|
||||
// users need to call the MakeKargs() function to create kargs.
|
||||
struct CommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
void* q_ptr;
|
||||
void* k_ptr;
|
||||
const void* knew_ptr;
|
||||
void* v_ptr;
|
||||
@@ -139,7 +139,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
template <bool Cond = !kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
MakeKargs(void* q_ptr,
|
||||
void* k_ptr,
|
||||
const void* knew_ptr,
|
||||
void* v_ptr,
|
||||
@@ -211,7 +211,7 @@ struct FmhaFwdAppendKVKernel
|
||||
|
||||
template <bool Cond = kIsGroupMode>
|
||||
__host__ static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
MakeKargs(void* q_ptr,
|
||||
void* k_ptr,
|
||||
const void* knew_ptr,
|
||||
void* v_ptr,
|
||||
@@ -384,9 +384,9 @@ struct FmhaFwdAppendKVKernel
|
||||
}
|
||||
|
||||
// for simplicity, we apply the batch stride by just offsetting the pointer
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
KDataType* k_ptr =
|
||||
reinterpret_cast<KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
|
||||
@@ -87,7 +87,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindow& q_dram_block_window, // M0*K0 tile
|
||||
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
|
||||
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
|
||||
@@ -261,6 +261,28 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
return tile_elementwise_in(vnew_element_func, vnew);
|
||||
}();
|
||||
store_tile(v_dram_block_window, vnew_tile);
|
||||
|
||||
if constexpr(RotaryEnum != BlockRotaryEmbeddingEnum::NONE)
|
||||
{
|
||||
auto q_window = make_tile_window(q_dram_block_window.get_bottom_tensor_view(),
|
||||
q_dram_block_window.get_window_lengths(),
|
||||
q_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeQDramTileDistribution<Problem>());
|
||||
|
||||
auto q_tile = [&]() {
|
||||
auto q = load_tile(q_window);
|
||||
return tile_elementwise_in(q_element_func, q);
|
||||
}();
|
||||
|
||||
// We assume that each thread owns contiguous elements along the head dimension. And we will
|
||||
// use the distribution to enable/disable threads in order to override q_tile content
|
||||
if constexpr(RotaryEnum == BlockRotaryEmbeddingEnum::INTERLEAVED) {}
|
||||
else // RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED
|
||||
{
|
||||
}
|
||||
|
||||
store_tile(q_dram_block_window, q_tile);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindow,
|
||||
@@ -271,7 +293,7 @@ struct BlockFmhaFwdAppendKVPipeline
|
||||
typename RotaryCosDramBlockWindow,
|
||||
typename RotarySinDramBlockWindow>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindow& q_dram_block_window,
|
||||
operator()(QDramBlockWindow& q_dram_block_window,
|
||||
KDramBlockWindow& k_dram_block_window,
|
||||
const KnewDramBlockWindow& knew_dram_block_window,
|
||||
VDramBlockWindow& v_dram_block_window,
|
||||
|
||||
@@ -57,6 +57,40 @@ struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
|
||||
return sizeof(KDataType) * Problem::kTileSizeSk * (Problem::kTileSizeD);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::kTileSizeS;
|
||||
constexpr index_t kKPerBlock = Problem::kTileSizeD;
|
||||
|
||||
constexpr index_t KPerThread = [&]() {
|
||||
if constexpr(Problem::RotaryEnum == BlockRotaryEmbeddingEnum::HALF_ROTATED)
|
||||
{
|
||||
return 8 / sizeof(QDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(QDataType);
|
||||
}
|
||||
}();
|
||||
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (NumWarps * MThreadPerWarp);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
|
||||
sequence<KThreadPerBlock, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user