mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 20:21:23 +00:00
Remove using partitioner for all fmha kernels (#1778)
* Remove using tile partitioner for fmha_fwd_kernel * Remove using tile partitioner for fmha_fwd_splitkv and splitkv-combine kernels * Remove using tile partitioner for fmha_fwd_appendkv kernel * Unify the format of GetTileIndex
This commit is contained in:
@@ -10,10 +10,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_, typename FmhaPipeline_>
|
||||
template <typename FmhaPipeline_>
|
||||
struct FmhaFwdAppendKVKernel
|
||||
{
|
||||
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
|
||||
@@ -234,12 +233,25 @@ struct FmhaFwdAppendKVKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_knew)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t seqlen_knew)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, seqlen_knew);
|
||||
// TODO: this may need tuning
|
||||
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, FmhaPipeline::kM0),
|
||||
ck_tile::integer_divide_ceil(seqlen_knew, FmhaPipeline::kN0)),
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& /* kargs */)
|
||||
{
|
||||
const index_t i_tile = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -247,7 +259,7 @@ struct FmhaFwdAppendKVKernel
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// divide problem
|
||||
const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}();
|
||||
const auto [i_tile, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
|
||||
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
|
||||
|
||||
@@ -20,10 +20,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaFwdKernel
|
||||
{
|
||||
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
@@ -84,7 +83,7 @@ struct FmhaFwdKernel
|
||||
return n.empty() ? n : std::string("p") + n; }();
|
||||
return
|
||||
_SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_"
|
||||
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
|
||||
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
|
||||
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
|
||||
@@ -867,9 +866,75 @@ struct FmhaFwdKernel
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
ck_tile::index_t hdim_v_,
|
||||
bool has_padded_seqlen_k = false)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
|
||||
// has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
|
||||
if(has_padded_seqlen_k)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(nhead_,
|
||||
batch_size_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
|
||||
{
|
||||
bool has_padded_seqlen_k = false;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
|
||||
|
||||
if(has_padded_seqlen_k)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 =
|
||||
ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.z;
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_batch = blockIdx.y;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
else
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 =
|
||||
ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -885,8 +950,7 @@ struct FmhaFwdKernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
|
||||
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
@@ -5,10 +5,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaFwdSplitKVCombineKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaPipeline = remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
|
||||
@@ -235,12 +234,35 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1),
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
|
||||
{
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -256,8 +278,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] =
|
||||
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v);
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <index_t kM0_, index_t kN1_>
|
||||
struct FmhaFwdSplitKVCombineTilePartitioner
|
||||
{
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1),
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -17,10 +17,9 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename TilePartitioner_, typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
template <typename FmhaPipeline_, typename EpiloguePipeline_>
|
||||
struct FmhaFwdSplitKVKernel
|
||||
{
|
||||
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
|
||||
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
|
||||
@@ -476,13 +475,35 @@ struct FmhaFwdSplitKVKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
|
||||
{
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
|
||||
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -498,8 +519,7 @@ struct FmhaFwdSplitKVKernel
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] =
|
||||
TilePartitioner{}(kargs.seqlen_q, kargs.hdim_v, kargs.num_splits);
|
||||
const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaFwdSplitKVTilePartitioner
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits,
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
|
||||
{
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [mn, i_split] = f(blockIdx.x, num_splits);
|
||||
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,105 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaFwdTilePartitioner
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
static constexpr const char* name = "shb";
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
using FmhaFwdTilePartitioner_SHB = FmhaFwdTilePartitioner<BlockFmhaShape_>;
|
||||
|
||||
template <typename BlockFmhaShape_>
|
||||
struct FmhaFwdTilePartitioner_HBS
|
||||
{
|
||||
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
|
||||
|
||||
static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
|
||||
static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
static constexpr const char* name = "hbs";
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(nhead_,
|
||||
batch_size_,
|
||||
ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.z;
|
||||
const index_t i_nhead = blockIdx.x;
|
||||
const index_t i_batch = blockIdx.y;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user