mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
[CK_TILE] More fmha splitkv optimizations (#1588)
* Use pre-defined constants for readability
* Use vector write for o_acc tensor
* Remove no-longer used policy method
* Deprecate no-longer used policy/pipeline
* Specify gemm0/gemm1 block warps separately in codegen
* Fix wrong ps_idx creation logic
* Add single-warp block gemm
* Supoprt single-warp gemm0
* Make MakeCBlockTile() as static method
* Use MakeCBlockTile() to get underlying tile distribution
* Use kNumGemm1Warps to compute # threads for gemm1
* Put normal case in the if clause
* Refine fmha splitkv block mapping
* Refine & fix the lse_acc/o_acc layout
* Fix wrong LDS size for K tile
* Use kK0=64 for hdim=128,256 fmha splitkv kernels
* Use kK1=64 for hdim=32,64,128 fmha splitkv kernels
* Undo kK0/kK1 changes
* Use more reasonable GetAlignmentV() computation
* Using store_tile() in fmha splitkv kernel epilogue
[ROCm/composable_kernel commit: 54f0e6f4bb]
This commit is contained in:
@@ -36,13 +36,12 @@ FMHA_FWD_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
|
||||
using fmha_block_warps_{F_idx} = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
|
||||
using fmha_warp_tile_{F_idx} = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
fmha_block_warps_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
fmha_warp_tile_{F_idx},
|
||||
fmha_block_warps_{F_idx},
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
fmha_warp_tile_{F_idx},
|
||||
{F_vlayout}>;
|
||||
|
||||
@@ -291,9 +290,12 @@ class FmhaFwdTileSize:
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0blen : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm : int # number of warps along q seqlen (block warps)
|
||||
F_rn : int # number of warps along k seqlen(not used)
|
||||
F_rk : int # number of warps along gemm-k(not used)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm : int # warp size along m (warp size)
|
||||
F_wn : int # warp size along n
|
||||
F_wk : int # warp size along k
|
||||
@@ -301,8 +303,8 @@ class FmhaFwdTileSize:
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0blen}" +\
|
||||
f"_r{self.F_rm}x{self.F_rn}x{self.F_rk}_w{self.F_wm}x{self.F_wn}x{self.F_wk}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm}x{self.F_wn}x{self.F_wk}" + ("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
@@ -334,9 +336,12 @@ class FmhaFwdKernel:
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0blen = self.F_tile.F_bk0blen,
|
||||
F_rm = self.F_tile.F_rm,
|
||||
F_rn = self.F_tile.F_rn,
|
||||
F_rk = self.F_tile.F_rk,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm = self.F_tile.F_wm,
|
||||
F_wn = self.F_tile.F_wn,
|
||||
F_wk = self.F_tile.F_wk,
|
||||
@@ -394,16 +399,16 @@ class FmhaFwdKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 16, -1),
|
||||
'32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, -1),
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -42,13 +42,12 @@ namespace {{
|
||||
template <bool kHasUnevenSplits>
|
||||
struct kernel_runner {{
|
||||
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0blen}>;
|
||||
using fmha_block_warps = ck_tile::sequence<{F_rm}, {F_rn}, {F_rk}>;
|
||||
using fmha_warp_tile = ck_tile::sequence<{F_wm}, {F_wn}, {F_wk}>;
|
||||
|
||||
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
|
||||
fmha_block_warps,
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
fmha_warp_tile,
|
||||
fmha_block_warps,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
fmha_warp_tile,
|
||||
{F_vlayout}>;
|
||||
|
||||
@@ -162,10 +161,12 @@ using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
|
||||
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
|
||||
fmha_pipeline_problem>;
|
||||
|
||||
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
|
||||
/// store_tile_raw() data corruption issue
|
||||
using fmha_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
false, false>>;
|
||||
|
||||
using fmha_kernel =
|
||||
ck_tile::FmhaFwdSplitKVCombineKernel<ck_tile::FmhaFwdSplitKVCombineTilePartitioner<{F_bm0}, {F_bn1}>,
|
||||
@@ -458,9 +459,12 @@ class FmhaFwdSplitKVKernel:
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0blen = self.F_tile.F_bk0blen,
|
||||
F_rm = self.F_tile.F_rm,
|
||||
F_rn = self.F_tile.F_rn,
|
||||
F_rk = self.F_tile.F_rk,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm = self.F_tile.F_wm,
|
||||
F_wn = self.F_tile.F_wn,
|
||||
F_wk = self.F_tile.F_wk,
|
||||
@@ -553,16 +557,16 @@ class FmhaFwdSplitKVCombineKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 16, 16, 16, -1),
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 16, 16, 16, -1),
|
||||
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, -1),
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 32, 32, 32, -1)
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, -1)
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
@@ -557,33 +557,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
#endif
|
||||
|
||||
struct
|
||||
{
|
||||
auto operator()(bool permute,
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
ck_tile::index_t s /*seqlen*/,
|
||||
ck_tile::index_t d /*hdim*/)
|
||||
{
|
||||
if(permute)
|
||||
return std::array<ck_tile::index_t, 4>{b, h, s, d};
|
||||
else
|
||||
return std::array<ck_tile::index_t, 4>{b, s, h, d};
|
||||
}
|
||||
|
||||
auto operator()(bool permute,
|
||||
ck_tile::index_t ns /*num_splits*/,
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
ck_tile::index_t s /*seqlen*/,
|
||||
ck_tile::index_t d /*hdim*/)
|
||||
{
|
||||
if(permute)
|
||||
return std::array<ck_tile::index_t, 5>{ns, b, h, s, d};
|
||||
else
|
||||
return std::array<ck_tile::index_t, 5>{ns, b, s, h, d};
|
||||
}
|
||||
} get_lengths;
|
||||
static const auto get_lengths = [](bool permute,
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
ck_tile::index_t s /*seqlen*/,
|
||||
ck_tile::index_t d /*hdim*/) {
|
||||
if(permute)
|
||||
return std::array<ck_tile::index_t, 4>{b, h, s, d};
|
||||
else
|
||||
return std::array<ck_tile::index_t, 4>{b, s, h, d};
|
||||
};
|
||||
|
||||
bool is_v_rowmajor = vlayout == std::string("r");
|
||||
|
||||
@@ -635,12 +618,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits || use_kvcache
|
||||
? std::array<ck_tile::index_t, 4>{num_splits, shape_batch, nhead, shape_seqlen_q}
|
||||
? std::array<ck_tile::index_t, 4>{shape_batch, nhead, num_splits, shape_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits || use_kvcache
|
||||
? get_lengths(o_perm, num_splits, shape_batch, nhead, shape_seqlen_q, hdim_v)
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
1 < num_splits || use_kvcache ? std::array<ck_tile::index_t, 5>{shape_batch,
|
||||
nhead,
|
||||
num_splits,
|
||||
shape_seqlen_q,
|
||||
hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// batch mode of lse data layout is [batch, nhead, seqlen_q]
|
||||
// group mode of lse data layout is [nhead, total_seqlen_q]
|
||||
@@ -880,7 +866,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_o_acc = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
const ck_tile::index_t stride_o_acc = (hdim_v);
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
@@ -906,8 +892,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = shape_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_o_acc = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
const ck_tile::index_t nhead_stride_lse_acc = (num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t nhead_stride_o_acc = (num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
@@ -922,13 +908,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * num_splits * shape_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * num_splits * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_block_table = (max_num_page_blocks / batch);
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_batch * nhead * shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (shape_batch * nhead * shape_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t split_stride_lse_acc = (shape_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (shape_seqlen_q * hdim_v);
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
|
||||
@@ -69,7 +69,8 @@ struct FmhaFwdKernel
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
using bfs = typename FmhaPipeline::BlockFmhaShape;
|
||||
using gbr = typename bfs::Gemm0BlockWarps;
|
||||
using g0br = typename bfs::Gemm0BlockWarps;
|
||||
using g1br = typename bfs::Gemm1BlockWarps;
|
||||
using gwt = typename bfs::Gemm0WarpTile;
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
@@ -85,7 +86,8 @@ struct FmhaFwdKernel
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_" + _SS_(TilePartitioner::name) + "_"
|
||||
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
|
||||
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
|
||||
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
|
||||
@@ -65,7 +65,8 @@ struct FmhaFwdSplitKVKernel
|
||||
// sync with generate.py
|
||||
// clang-format off
|
||||
using bfs = typename FmhaPipeline::BlockFmhaShape;
|
||||
using gbr = typename bfs::Gemm0BlockWarps;
|
||||
using g0br = typename bfs::Gemm0BlockWarps;
|
||||
using g1br = typename bfs::Gemm1BlockWarps;
|
||||
using gwt = typename bfs::Gemm0WarpTile;
|
||||
#define _SS_ std::string
|
||||
#define _TS_ std::to_string
|
||||
@@ -81,7 +82,8 @@ struct FmhaFwdSplitKVKernel
|
||||
"_" + (kIsGroupMode ? "group" : "batch") + "_"
|
||||
"b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
|
||||
_TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK0BlockLength) + "_" +
|
||||
"r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
|
||||
"r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
|
||||
"r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
|
||||
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
|
||||
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
|
||||
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
|
||||
@@ -894,7 +896,7 @@ struct FmhaFwdSplitKVKernel
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_o_acc, 1),
|
||||
number<1>{},
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
|
||||
@@ -26,8 +26,8 @@ struct FmhaFwdSplitKVTilePartitioner
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1),
|
||||
nhead * num_splits,
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1) * num_splits,
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
@@ -42,8 +42,9 @@ struct FmhaFwdSplitKVTilePartitioner
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(blockIdx.x, num_tile_n1);
|
||||
const auto [i_nhead, i_split] = f(blockIdx.y, num_splits);
|
||||
const auto [mn, i_split] = f(blockIdx.x, num_splits);
|
||||
const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
|
||||
|
||||
@@ -64,6 +64,9 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentOacc =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentOacc<Problem>();
|
||||
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
@@ -252,11 +255,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window = make_tile_window(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
|
||||
v_dram_block_window_lengths,
|
||||
|
||||
@@ -9,11 +9,20 @@
|
||||
namespace ck_tile {
|
||||
|
||||
// This pipeline is qkv all located in LDS
|
||||
using BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy =
|
||||
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopyK = */ false,
|
||||
/* AsyncCopyV = */ false,
|
||||
/* NumPrefetchK = */ 1,
|
||||
/* NumPrefetchV = */ 1>;
|
||||
struct BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopyK = */ false,
|
||||
/* AsyncCopyV = */ false,
|
||||
/* NumPrefetchK = */ 1,
|
||||
/* NumPrefetchV = */ 1>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
|
||||
{
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
|
||||
return static_cast<index_t>(16 / sizeof(OaccDataType));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -39,8 +39,11 @@ struct BlockFmhaPipelineProblem
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
@@ -84,8 +87,11 @@ struct BlockFmhaFwdSplitKVPipelineProblem
|
||||
using FmhaMask = remove_cvref_t<FmhaMask_>;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
|
||||
@@ -242,11 +242,11 @@ struct BlockFmhaPipelineQRKSVS
|
||||
{seqlen_k_start, 0});
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window = make_tile_window(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
|
||||
@@ -314,11 +314,11 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
}();
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window = make_tile_window(
|
||||
bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
|
||||
randval_dram_block_window_tmp, seqlen_k_start);
|
||||
|
||||
@@ -9,9 +9,10 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// NOTICE: we no-longer use this pipeline.
|
||||
// This pipeline is qkv all located in LDS
|
||||
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
|
||||
struct BlockFmhaPipelineQSKSVS
|
||||
struct [[deprecated]] BlockFmhaPipelineQSKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
|
||||
|
||||
// TODO: remove this
|
||||
#define K_LDS_LOAD_USE_OFFSET_TRANSFORM 0
|
||||
@@ -64,13 +65,28 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
constexpr index_t M1 = MWarp;
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
if constexpr(1 < Problem::kNumGemm0Warps)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(MWarp == 1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2>>,
|
||||
tuple<sequence<2, 1>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -80,7 +96,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
@@ -129,12 +145,16 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
if constexpr(1 < Problem::kNumGemm0Warps)
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
else
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
/// NOTICE: we no-longer use this policy.
|
||||
template <>
|
||||
struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
|
||||
{
|
||||
static constexpr bool QLoadOnce = false;
|
||||
|
||||
@@ -364,12 +384,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kMaxVecLoad =
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
|
||||
// TODO: not correct!
|
||||
if constexpr(total_pixels > 4)
|
||||
return 4;
|
||||
else
|
||||
return 2;
|
||||
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: (total_pixels / kMinVecLoad);
|
||||
|
||||
return kVecLoad;
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -383,10 +406,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
using BlockGemm = remove_cvref_t<decltype(QXPolicy::template GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
constexpr auto vec =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
|
||||
return vec;
|
||||
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -395,10 +416,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
constexpr auto vec =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
|
||||
return vec;
|
||||
|
||||
return WG::WarpGemmAttribute::Impl::kCM1PerLane;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -449,44 +468,12 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
return max(SingleKSize, SingleVSize);
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0BlockLength;
|
||||
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
|
||||
|
||||
constexpr auto q_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
q_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
|
||||
|
||||
return q_block_dstr;
|
||||
}
|
||||
|
||||
// TODO: this is used for non async copy desc. unify in the future
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
@@ -886,36 +873,10 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, typename BlockGemm>
|
||||
template <typename BlockGemm>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasDramTileDistribution()
|
||||
{
|
||||
constexpr index_t MPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t NPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
|
||||
// Construct C-Block-HostTensor
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
return c_block_dstr;
|
||||
return BlockGemm::MakeCBlockTile().get_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -972,7 +933,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
BlockGemmProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::kBlockSize,
|
||||
Problem::kNumGemm1Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
|
||||
@@ -21,10 +21,15 @@ struct TileFmhaShape
|
||||
using Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>;
|
||||
using Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
static constexpr index_t NumGemm0Warps =
|
||||
reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{});
|
||||
static constexpr index_t NumGemm1Warps =
|
||||
reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{});
|
||||
static_assert(NumGemm1Warps % NumGemm0Warps == 0);
|
||||
|
||||
static_assert(NumWarps == reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{}));
|
||||
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps);
|
||||
|
||||
static_assert(std::is_same_v<Gemm0WarpTile, Gemm1WarpTile>);
|
||||
|
||||
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
|
||||
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_one_warp_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
|
||||
@@ -157,7 +157,7 @@ struct BlockGemmARegBRegCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -0,0 +1,237 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A is block distributed tensor
|
||||
// B is block window on shared memory
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_ = BlockGemmARegBSmemCRegV1DefaultPolicy>
|
||||
struct BlockGemmARegBSmemCRegOneWarpV1
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static_assert(kBlockSize == get_warp_size(), "Check failed!");
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
|
||||
// KPerBlock == BlockGemmShape::kK,
|
||||
// "wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp;
|
||||
constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp;
|
||||
|
||||
const index_t iNWarp = 0;
|
||||
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto a_block_dstr = make_static_tile_distribution(a_block_dstr_encode);
|
||||
|
||||
// constrcut from A-block-tensor from A-Block-tensor-tmp
|
||||
// FIXME: need method to check a_block_tensor and a_block_tensor_tmp have equivalent
|
||||
// distribution
|
||||
auto a_block_tensor =
|
||||
make_static_distributed_tensor<typename ABlockTensorTmp::DataType>(a_block_dstr);
|
||||
|
||||
a_block_tensor.get_thread_buffer() = a_block_tensor_tmp.get_thread_buffer();
|
||||
|
||||
// construct B-warp-window
|
||||
auto b_warp_window_tmp = make_tile_window(
|
||||
b_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(number<WG::kN>{}, number<WG::kK>{}),
|
||||
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
|
||||
make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
|
||||
|
||||
#if 0 // FIXME: using array will cause register spill
|
||||
array<array<decltype(b_warp_window_tmp), KIterPerWarp>, NIterPerWarp> b_warp_windows{
|
||||
{b_warp_window_tmp}};
|
||||
|
||||
for(index_t nIter = 0; nIter < NIterPerWarp; nIter++)
|
||||
{
|
||||
for(index_t kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
}
|
||||
}
|
||||
#else
|
||||
statically_indexed_array<
|
||||
statically_indexed_array<decltype(b_warp_window_tmp), KIterPerWarp>,
|
||||
NIterPerWarp>
|
||||
b_warp_windows;
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
|
||||
|
||||
move_tile_window(b_warp_windows(nIter)(kIter),
|
||||
{nIter * NPerBlockPerIter, kIter * KPerBlockPerIter});
|
||||
});
|
||||
});
|
||||
#endif
|
||||
|
||||
// check C-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"wrong!");
|
||||
|
||||
using AWarpDstr = typename WG::AWarpDstr;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WG::AWarpTensor;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A block tensor
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B Block window
|
||||
const auto b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static_assert(MWarp == 1 && NWarp == 1, "Check failed!");
|
||||
|
||||
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
|
||||
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
|
||||
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
|
||||
|
||||
constexpr auto c_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
static_assert(decltype(c_block_dstr_encode)::NDimP == 1, "Check failed!");
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensorTmp, typename BBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BBlockWindowTmp& b_block_window_tmp) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -181,7 +181,7 @@ struct BlockGemmARegBSmemCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -182,7 +182,7 @@ struct BlockGemmARegBSmemCRegV2
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -180,7 +180,7 @@ struct BlockGemmASmemBRegCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -167,7 +167,7 @@ struct BlockGemmASmemBSmemCRegV1
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
|
||||
@@ -22,7 +22,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
const auto ps_idx = make_array<index_t>(get_block_id(), get_lane_id());
|
||||
const auto ps_idx = detail::get_partition_index(acc_tensor.get_tile_distribution());
|
||||
const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
|
||||
|
||||
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
Reference in New Issue
Block a user