mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 00:40:09 +00:00
[CK_TILE] fmha forward split-kv + combine kernels (#1338)
* FA fwd dropout * FA bwd * epilogue reuse * CMakeLists update * [CK_TILE] support alibi (#1269) * add alibi support * fix code * update code based on comment * Support more hdim * fix fp8 bias * support seqlen_k=0 case * remove unused printf * fix format --------- Co-authored-by: rocking <ChunYu.Lai@amd.com> * now fwd/bwd can build * bwd alibi * add bwd validation stream_config * update generated filenames * update bwd kernel launch * CK_TILE_HOST_DEVICE in philox * Transpose -> transpose * format * format * format * Generate the instance for FA required * format * fix error in WarpGemm * Add num_splits option and dummy split-kv api method * Generate fmha_fwd_splitkv() * Add SplitKV kernel codegen logics * Add SplitKV combine kernel codegen logics * Fix mismatched return type * Clean-up code * Replace sentinel value before storing * Fix wrong layout of LSE/LSEacc/Oacc * Format codes * Fix o_acc memory error * Fix wrong kBlockSize used in policy * Reduce # of combine kernels * Fix split-kv combine kernel name * Fix wrong LDS indexing logics * Fix wrong loop counter step logic * Undo vector size changes * Remove no-longer used field * Remove in-consistent comment * Remove debug statements in example * Remove more debug statements * Add constness to local variables * Clearn up generate.py * Fix unstable clang-format comment * Remove unused include directive * Use shorter template parameter name * Enable non-split-kv blobs * Update license date * Print num_splits conditionally * Undo disabling data types * Remove unnessary tile size for fp8 * Fix wrong pipeline args for fp8 * Fix example output format * Remove more debug code in combine pipeline * Add stride kernel arguments for LSE/O acc workspace * Re-order split-kv pipeline call operator arguments * Pass LSE/O strides in kernel argument * Re-order pipeline call operator arguments * Use tensor_descriptor to locate LSEacc elements * Support providing invalid element for tensor view * Set invalid element value for LSEacc tensor view * Remove hand-written store_tile() code * Remove necessary value-overwrite logic * Add transposed lds descriptor * Support load_tile() for tile_window_with_static_lengths<> * Undo removing necessary value-overwrite logic * Use read descriptor to locate lds elements * Simplify pipeline source code * Add constraint to kMaxSplits * Default use kMaxSplits=64 in generate.py * Revert "Add constraint to kMaxSplits" This reverts commit0a2132d758. * Revert "Default use kMaxSplits=64 in generate.py" This reverts commitc7d9c80b77. * Decide alignment by the padding parameter * Remove no-longer used utility functions * Remove not-working code * Add comment & remove no-longer used code * Fix computation errors * Add heuristic to override num_splits option * Add constraint to kMaxSplits * Fix compilation error * Clean up pipeline code * Wrap pointer access as lambda function * Rename confusing methods * Use kLogMasSplits as template parameter * Finish splitkv combine kernel codegen * Update kMaxSplits limit * Use smaller kM0 for splitkv combine kernel * Ignore droupout flag in splitkv pipeline * Unify flag usage * Add back flag kStoreLSE * Merge lambda calls in pipeline * Fix compilation errors * Avoid all empty splits * Always check for empty loop in splitkv pipelines * Re-order parameters * Remove redundant p_drop option check * Add traits/problem for fwd splitkv kernel * Conditionally enable uneven split boundary checks * Add comment for the splitkv traits field * Change even split criteria * Re-order statements * Refine occupancy value for hdim=128&256 * Refine occupancy value for hdim=32&64 * Remove redundant kernel argument * Separate fmha bwd codegen logics * Separate fmha fwd codegen logics * Remove redundant direction parameter in fwd&bwd codegen logics * Support generate multiple APIs for an example * Let 'api' an alias of 'direction' option * Remove choices for the 'direction' option * Use dictionary to config all the functions * Move fmha splitkv codegen logics to other file * Add fwd_splitkv api for tile_example_fmha_fwd --------- Co-authored-by: danyao12 <danyao12> Co-authored-by: carlushuang <carlus.huang@amd.com> Co-authored-by: rocking <ChunYu.Lai@amd.com> Co-authored-by: Jing Zhang <jizhan@amd.com>
This commit is contained in:
@@ -114,6 +114,9 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("drop_seed", "1", "seed for random number generator")
|
||||
.insert("drop_offset", "0", "offset for random number generator")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("num_splits",
|
||||
"1",
|
||||
"# of splits for key/value. 0 to determine actual number by heuristic")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
@@ -155,6 +158,106 @@ auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
}
|
||||
}
|
||||
|
||||
int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits)
|
||||
{
|
||||
// If we have enough to almost fill the SMs, then just use 1 split
|
||||
if(batch_nhead_mblocks >= 0.8f * num_SMs)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
max_splits = std::min({max_splits, num_SMs, num_n_blocks});
|
||||
float max_efficiency = 0.f;
|
||||
std::vector<float> efficiency;
|
||||
efficiency.reserve(max_splits);
|
||||
auto ceildiv = [](int a, int b) { return (a + b - 1) / b; };
|
||||
// Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits,
|
||||
// we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks
|
||||
// (i.e. it's 11 splits anyway).
|
||||
// So we check if the number of blocks per split is the same as the previous num_splits.
|
||||
auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) {
|
||||
return num_splits == 1 ||
|
||||
ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1);
|
||||
};
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
{
|
||||
efficiency.push_back(0.f);
|
||||
}
|
||||
else
|
||||
{
|
||||
float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs;
|
||||
float eff = n_waves / ceil(n_waves);
|
||||
// printf("num_splits = %d, eff = %f\n", num_splits, eff);
|
||||
if(eff > max_efficiency)
|
||||
{
|
||||
max_efficiency = eff;
|
||||
}
|
||||
efficiency.push_back(eff);
|
||||
}
|
||||
}
|
||||
for(int num_splits = 1; num_splits <= max_splits; num_splits++)
|
||||
{
|
||||
if(!is_split_eligible(num_splits))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
if(efficiency[num_splits - 1] >= 0.85 * max_efficiency)
|
||||
{
|
||||
// printf("num_splits chosen = %d\n", num_splits);
|
||||
return num_splits;
|
||||
}
|
||||
}
|
||||
return 1;
|
||||
}
|
||||
|
||||
int override_num_splits_if_necessary(
|
||||
int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits)
|
||||
{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
hipDeviceProp_t props{};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess)
|
||||
{
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
// tile size should match the generate.py
|
||||
const int kM0 = 64;
|
||||
const int kN1 = hdim_v;
|
||||
|
||||
const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0);
|
||||
const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
if(num_splits < 1 && p_drop == 0.0f)
|
||||
{
|
||||
return num_splits_heuristic(
|
||||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128);
|
||||
}
|
||||
|
||||
return num_splits;
|
||||
}
|
||||
|
||||
float fmha_fwd_dispatch(fmha_fwd_traits traits,
|
||||
fmha_fwd_args args,
|
||||
const ck_tile::stream_config& config)
|
||||
{
|
||||
if(1 < args.num_splits)
|
||||
{
|
||||
return fmha_fwd_splitkv(traits, args, config);
|
||||
}
|
||||
else
|
||||
{
|
||||
return fmha_fwd(traits, args, config);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
@@ -260,6 +363,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
int num_splits = arg_parser.get_int("num_splits");
|
||||
|
||||
int stream_warmup = arg_parser.get_int("warmup");
|
||||
int stream_repeat = arg_parser.get_int("repeat");
|
||||
bool kname = arg_parser.get_bool("kname");
|
||||
@@ -320,6 +425,18 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
}
|
||||
|
||||
// legalize num_splits according to other options
|
||||
if(num_splits < 1)
|
||||
{
|
||||
num_splits = override_num_splits_if_necessary(
|
||||
batch, nhead, max_seqlen_q, hdim_v, p_drop, num_splits);
|
||||
}
|
||||
if(128 < num_splits)
|
||||
{
|
||||
std::cerr << "num_splits greater than 128 is not supported" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
auto get_lengths = [&](bool permute,
|
||||
ck_tile::index_t b /*batch*/,
|
||||
ck_tile::index_t h /*nhead*/,
|
||||
@@ -361,7 +478,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: std::array<ck_tile::index_t, 2>{batch, nhead})
|
||||
: std::array<ck_tile::index_t, 2>{1, 1});
|
||||
|
||||
// self define lse data layout as [shape_batch, nhead, shape_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_acc_host(
|
||||
1 < num_splits ? std::array<ck_tile::index_t, 4>{num_splits, batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1});
|
||||
ck_tile::HostTensor<OaccDataType> o_acc_host(
|
||||
1 < num_splits
|
||||
? std::array<ck_tile::index_t, 5>{num_splits, batch, nhead, max_seqlen_q, hdim_v}
|
||||
: std::array<ck_tile::index_t, 5>{1, 1, 1, 1, 1});
|
||||
|
||||
// self define lse data layout as [batch, nhead, max_seqlen_q]
|
||||
ck_tile::HostTensor<LSEDataType> lse_host(
|
||||
lse ? std::array<ck_tile::index_t, 3>{batch, nhead, max_seqlen_q}
|
||||
: std::array<ck_tile::index_t, 3>{1, 1, 1} /* dummy shape for simplifying code */);
|
||||
@@ -443,6 +568,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem bias_buf(bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_acc_buf(lse_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_acc_buf(o_acc_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem lse_buf(lse_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_buf(o_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem seqstart_q(seqstart_q_host.size() * sizeof(int32_t));
|
||||
@@ -479,7 +606,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
: (std::string("(") + std::to_string(seqlen_kpads[0]) + ")"))
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
|
||||
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
|
||||
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
|
||||
<< ", mask:" << mask << ", v:" << vlayout;
|
||||
if(1 < num_splits)
|
||||
{
|
||||
std::cout << ", num_splits:" << num_splits;
|
||||
}
|
||||
std::cout << std::flush;
|
||||
|
||||
auto fmha_traits = fmha_fwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
@@ -523,6 +655,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}();
|
||||
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
|
||||
const ck_tile::index_t stride_randval = (max_seqlen_k);
|
||||
const ck_tile::index_t stride_o_acc = hdim_v;
|
||||
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
|
||||
// setup nhead_stride_* arguments
|
||||
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
|
||||
@@ -537,6 +670,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_randval = (shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t nhead_stride_lse = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_lse_acc = max_seqlen_q;
|
||||
const ck_tile::index_t nhead_stride_o_acc = (max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
|
||||
// setup batch_stride_* arguments
|
||||
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
|
||||
@@ -545,7 +680,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_randval = (nhead * shape_seqlen_q * max_seqlen_k);
|
||||
const ck_tile::index_t batch_stride_lse = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_lse_acc = (nhead * max_seqlen_q);
|
||||
const ck_tile::index_t batch_stride_o_acc = (nhead * max_seqlen_q * hdim_v);
|
||||
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
|
||||
// setup split_stride_* arguments (only used in split-kv kernel)
|
||||
const ck_tile::index_t split_stride_lse_acc = (batch * nhead * max_seqlen_q);
|
||||
const ck_tile::index_t split_stride_o_acc = (batch * nhead * max_seqlen_q * hdim_v);
|
||||
|
||||
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
|
||||
k_buf.GetDeviceBuffer(),
|
||||
@@ -553,6 +693,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bias.type == bias_enum::alibi ? alibi_slope_buf.GetDeviceBuffer()
|
||||
: bias_buf.GetDeviceBuffer(),
|
||||
randval_buf.GetDeviceBuffer(),
|
||||
lse_acc_buf.GetDeviceBuffer(),
|
||||
o_acc_buf.GetDeviceBuffer(),
|
||||
lse_buf.GetDeviceBuffer(),
|
||||
o_buf.GetDeviceBuffer(),
|
||||
seqstart_q.GetDeviceBuffer(),
|
||||
@@ -566,6 +708,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
hdim_v,
|
||||
nhead,
|
||||
nhead_k,
|
||||
num_splits,
|
||||
scale_s,
|
||||
scale_p,
|
||||
scale_o,
|
||||
@@ -575,6 +718,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
bias.type == bias_enum::alibi ? (bias.rank_info == 0 ? 0 : nhead)
|
||||
: stride_bias,
|
||||
stride_randval,
|
||||
stride_o_acc,
|
||||
stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
@@ -582,6 +726,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
nhead_stride_bias,
|
||||
nhead_stride_randval,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
@@ -589,7 +735,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
batch_stride_bias,
|
||||
batch_stride_randval,
|
||||
batch_stride_lse,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
batch_stride_o,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc,
|
||||
mask.left,
|
||||
mask.right,
|
||||
static_cast<ck_tile::index_t>(mask.type),
|
||||
@@ -598,7 +748,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{drop_seed, drop_offset}};
|
||||
}();
|
||||
|
||||
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
float ave_time = fmha_fwd_dispatch(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
@@ -849,14 +999,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
lse_host_result.ForEach(
|
||||
[&](auto& self, auto idx) { self(idx) = lse_host(wb, idx[0], idx[1]); });
|
||||
|
||||
bool lse_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
cur_pass = ck_tile::check_err(lse_host_result,
|
||||
lse_host_ref,
|
||||
"LSE Error: Incorrect results!",
|
||||
rtol,
|
||||
atol,
|
||||
/* allow_infinity_ref = */ true);
|
||||
|
||||
pass &= lse_pass;
|
||||
pass &= cur_pass;
|
||||
if(!cur_pass)
|
||||
{
|
||||
std::cerr << "LSE mismatch found at batch: " << wb << std::endl
|
||||
|
||||
Reference in New Issue
Block a user