mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
now fwd/bwd can build
This commit is contained in:
@@ -67,3 +67,9 @@ list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal)
|
||||
|
||||
target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS})
|
||||
target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
|
||||
@@ -57,7 +57,11 @@ auto create_args(int argc, char* argv[])
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "1", "permute output")
|
||||
.insert("bias", "0", "add bias or not")
|
||||
.insert("bias",
|
||||
"n",
|
||||
"n or 0, no bias\n"
|
||||
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
|
||||
"a(libi) or 2, alibi with 1*h. a:1, b*h")
|
||||
.insert("dbias", "0", "output bias gradient or not")
|
||||
.insert("prec", "fp16", "data type. fp16 or bf16")
|
||||
.insert("mask",
|
||||
@@ -136,12 +140,12 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(scale == .0f)
|
||||
scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q));
|
||||
|
||||
bool use_bias = arg_parser.get_bool("bias");
|
||||
bias_info bias = bias_info::decode(arg_parser.get_str("bias"));
|
||||
bool use_dbias = arg_parser.get_bool("dbias");
|
||||
float p_drop = arg_parser.get_float("p_drop");
|
||||
uint64_t drop_seed = arg_parser.get_uint64("drop_seed");
|
||||
uint64_t drop_offset = arg_parser.get_uint64("drop_offset");
|
||||
if(use_dbias && !use_bias)
|
||||
if(use_dbias && bias.type != bias_enum::elementwise_bias)
|
||||
{
|
||||
std::cerr << "dbias only exists when there is a bias" << std::endl;
|
||||
return false;
|
||||
@@ -266,7 +270,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// use bias shape = [1, 1, shape_seqlen_q, max_seqlen_k]. if use_bias=false, the bias_host
|
||||
// will not be used for verification at all (but will be copied to device anyway).
|
||||
ck_tile::HostTensor<BiasDataType> bias_host(
|
||||
use_bias
|
||||
bias.type == bias_enum::elementwise_bias
|
||||
? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k)
|
||||
: std::array<ck_tile::index_t, 4>{1, 1, 1, 1} /* dummy shape for simplifying code */);
|
||||
ck_tile::HostTensor<ODataType> o_host(
|
||||
@@ -354,7 +358,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias
|
||||
<< ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask
|
||||
<< std::flush;
|
||||
|
||||
@@ -363,7 +367,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
data_type,
|
||||
mode == mode_enum::group,
|
||||
mask.type,
|
||||
use_bias,
|
||||
bias.type,
|
||||
use_dbias,
|
||||
p_drop > 0.0f};
|
||||
auto fmha_args = [&]() {
|
||||
@@ -556,7 +560,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k
|
||||
|
||||
if(use_bias)
|
||||
if(bias.type == bias_enum::elementwise_bias)
|
||||
{
|
||||
// clang-format off
|
||||
ck_tile::HostTensor<BiasDataType> bias_host_ref({1, real_seqlen_q, real_seqlen_k});
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "mask.hpp"
|
||||
#include "bias.hpp"
|
||||
#include <type_traits>
|
||||
|
||||
template <typename DataType>
|
||||
@@ -291,7 +292,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kIsGroupMode_,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum FmhaBwdPipelineEnum_,
|
||||
typename FmhaMask_,
|
||||
bool kHasBias_,
|
||||
ck_tile::BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kHasDropout_,
|
||||
bool kPadS_,
|
||||
@@ -305,7 +306,7 @@ struct fmha_bwd_dq_dk_dv_traits_
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr auto FmhaBwdPipelineEnum = FmhaBwdPipelineEnum_;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<FmhaMask_>;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kPadS = kPadS_;
|
||||
@@ -338,7 +339,7 @@ struct fmha_bwd_traits
|
||||
std::string data_type;
|
||||
bool is_group_mode;
|
||||
mask_enum mask_type;
|
||||
bool has_bias;
|
||||
bias_enum bias_type;
|
||||
bool has_dbias;
|
||||
bool has_dropout;
|
||||
// TODO: padding check is inside this api
|
||||
|
||||
@@ -204,7 +204,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
// scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)]
|
||||
scale_o = range_p * range_v / range_o / dtype_max;
|
||||
}
|
||||
|
||||
|
||||
std::string vlayout = arg_parser.get_str("vlayout");
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
|
||||
@@ -424,8 +424,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch
|
||||
<< ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k
|
||||
<< ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias
|
||||
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout
|
||||
<< std::flush;
|
||||
<< ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant
|
||||
<< ", mask:" << mask << ", v:" << vlayout << std::flush;
|
||||
|
||||
auto fmha_traits = fmha_fwd_traits{hdim_q,
|
||||
hdim_v,
|
||||
|
||||
@@ -91,7 +91,7 @@ struct fmha_fwd_args
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
const void* bias_ptr; // bias or alibi_slope pointer
|
||||
void* rand_val_ptr;
|
||||
void* lse_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
@@ -681,7 +681,7 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
|
||||
FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>;
|
||||
@@ -772,7 +772,7 @@ class FmhaBwdApiPool:
|
||||
if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")):
|
||||
continue
|
||||
inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
|
||||
|
||||
@@ -868,7 +868,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
F_skpad = BOOL_MAP[self.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_bias = BOOL_MAP[self.F_bias],
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
F_dbias = BOOL_MAP[self.F_dbias],
|
||||
F_dropout = BOOL_MAP[self.F_dropout],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
@@ -890,7 +890,7 @@ class FmhaBwdDQDKDVKernel:
|
||||
mn = mask_name()
|
||||
n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name +\
|
||||
f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\
|
||||
f"_b{BOOL_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}"
|
||||
(f'_{self.F_bias}' if self.F_bias != 'no' else '') + f"_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}"
|
||||
if mn != '' : n += f'{mn}'
|
||||
return n
|
||||
|
||||
@@ -944,13 +944,13 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[Fm
|
||||
d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
tile = d[hdim_str][0]
|
||||
ppl = d[hdim_str][1]
|
||||
hdim = int(hdim_str)
|
||||
if (mode == "group") and (spad == "f" or skpad == "f"):
|
||||
continue
|
||||
if (bias == "f" and dbias == "t"):
|
||||
if (bias == "no" and dbias == "t"):
|
||||
continue
|
||||
k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
|
||||
Reference in New Issue
Block a user