Remove f8 pipeline, we should share the same pipeline even in f8

This commit is contained in:
rocking
2024-04-08 09:56:23 +00:00
parent f7d81364f3
commit 5c3fdeb0b8
5 changed files with 30 additions and 94 deletions

View File

@@ -49,9 +49,6 @@ auto create_args(int argc, char* argv[])
.insert("d", "128", "head dim for q, k")
.insert("d_v", "0", "head dim for v, 0 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
.insert("descale_q", "1", "scale factor for fp8 quantization")
.insert("descale_k", "1", "scale factor for fp8 quantization")
.insert("descale_v", "1", "scale factor for fp8 quantization")
.insert("iperm",
"1",
"permute input\n"
@@ -140,10 +137,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(scale == .0f)
scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
float descale_q = arg_parser.get_float("descale_q");
float descale_k = arg_parser.get_float("descale_k");
float descale_v = arg_parser.get_float("descale_v");
std::string vlayout = arg_parser.get_str("vlayout");
bool use_bias = arg_parser.get_bool("bias");
bool lse = arg_parser.get_bool("lse");
@@ -384,9 +377,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.y,
mask.x,
ck_tile::identity{},
ck_tile::identity{},
descale_q * descale_k,
descale_v};
ck_tile::identity{}};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);

View File

@@ -125,8 +125,6 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t max_seqlen_q,
float scale,
float descale_qk,
float descale_sv,
bool i_perm,
bool o_perm,
ck_tile::index_t mask_y,
@@ -199,9 +197,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
nhead_stride_lse,
nhead_stride_o,
mask_y,
mask_x,
descale_qk,
descale_sv);
mask_x);
}
else
{ // create batch mode kernel arguments
@@ -235,9 +231,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
batch_stride_lse,
batch_stride_o,
mask_y,
mask_x,
descale_qk,
descale_sv);
mask_x);
}
}();
@@ -266,8 +260,6 @@ struct fmha_fwd_args
ck_tile::index_t hdim_v;
ck_tile::index_t max_seqlen_q;
float scale;
float descale_qk;
float descale_sv;
bool i_perm;
bool o_perm;
ck_tile::index_t mask_y;
@@ -324,8 +316,6 @@ struct fmha_fwd_args
// typename ElementFunctions::SAccElementFunction s_acc_element_func;
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
typename ElementFunctions::OAccElementFunction o_acc_element_func;
float descale_qk;
float descale_sv;
};
template <typename FmhaKernel, typename FmhaFwdArgs>
@@ -369,9 +359,7 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
// args.lse_element_func,
// args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
args.descale_sv);
args.o_acc_element_func);
}
else
{ // create batch mode kernel arguments
@@ -413,9 +401,7 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
// args.lse_element_func,
// args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
args.descale_sv);
args.o_acc_element_func);
}
}();

View File

@@ -42,7 +42,6 @@ LAYOUT_MAP = {
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_fp8" : "ck_tile::BlockFmhaPipelineQRKSVSFp8",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}
@@ -212,7 +211,7 @@ class FmhaFwdApiTrait:
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0'
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@@ -228,7 +227,7 @@ class FmhaFwdApiTrait:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qr_fp8']:
elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0'
else : return f'a.hdim_q % {self.bk0blen} == 0'
else: assert False
@@ -239,7 +238,7 @@ class FmhaFwdApiTrait:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qr_fp8']:
elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0'
else : return f'a.hdim_v % {self.bk0blen} == 0'
else: assert False
@@ -450,7 +449,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
elif dtype in ['fp8', 'bf8']:
# no need lse kernels
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
else:
assert False
return pipelines

View File

@@ -166,13 +166,6 @@ struct FmhaFwdKernel
ck_tile::index_t mask_y, mask_x;
};
struct FmhaFwdFP8Kargs
{
float descale_qk; // q*k
float descale_sv; // s*v
// float * o_amax_ptr;
};
struct FmhaFwdCommonLSEKargs
{
void* lse_ptr = nullptr;
@@ -188,8 +181,7 @@ struct FmhaFwdKernel
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdBatchModeBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kIsFp8, FmhaFwdFP8Kargs, FmhaFwdEmptyKargs<3>>
std::conditional_t<kStoreLSE, FmhaFwdBatchModeLSEKargs, FmhaFwdEmptyKargs<2>>
{
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
@@ -201,8 +193,7 @@ struct FmhaFwdKernel
: FmhaFwdCommonKargs,
std::conditional_t<kHasBias, FmhaFwdCommonBiasKargs, FmhaFwdEmptyKargs<0>>,
std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
std::conditional_t<kIsFp8, FmhaFwdFP8Kargs, FmhaFwdEmptyKargs<3>>
std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
@@ -251,9 +242,7 @@ struct FmhaFwdKernel
// LSEElementFunction lse_element_func,
// SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func,
float descale_qk,
float descale_sv)
OAccElementFunction o_acc_element_func)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -288,7 +277,6 @@ struct FmhaFwdKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8 args
batch_stride_q,
batch_stride_k,
batch_stride_v,
@@ -312,11 +300,6 @@ struct FmhaFwdKernel
kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
}
if constexpr(kIsFp8)
{
kargs.descale_qk = descale_qk;
kargs.descale_sv = descale_sv;
}
return kargs;
}
@@ -356,9 +339,7 @@ struct FmhaFwdKernel
// LSEElementFunction lse_element_func,
// SAccElementFunction s_acc_element_func,
PComputeElementFunction p_compute_element_func,
OAccElementFunction o_acc_element_func,
float descale_qk,
float descale_sv)
OAccElementFunction o_acc_element_func)
{
Kargs kargs{{q_ptr,
k_ptr,
@@ -393,7 +374,6 @@ struct FmhaFwdKernel
{}, // placeholder for bias
{}, // placeholder for mask
{}, // placeholder for lse
{}, // placeholder for fp8 args
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -414,11 +394,6 @@ struct FmhaFwdKernel
kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse;
}
if constexpr(kIsFp8)
{
kargs.descale_qk = descale_qk;
kargs.descale_sv = descale_sv;
}
return kargs;
}
@@ -702,38 +677,22 @@ struct FmhaFwdKernel
}();
auto o_acc_tile = [&]() {
if constexpr(kIsFp8)
{
return FmhaPipeline{}(q_dram_window,
k_dram_window,
v_dram_window,
bias_dram_window,
lse_dram_window,
mask,
kargs.scale,
kargs.descale_qk,
kargs.descale_sv,
smem_ptr);
}
else
{
return FmhaPipeline{}(q_dram_window,
identity{},
k_dram_window,
identity{},
v_dram_window,
identity{},
bias_dram_window,
identity{},
lse_dram_window,
identity{},
identity{},
kargs.p_compute_element_func,
kargs.o_acc_element_func,
mask,
kargs.scale,
smem_ptr);
}
return FmhaPipeline{}(q_dram_window,
identity{},
k_dram_window,
identity{},
v_dram_window,
identity{},
bias_dram_window,
identity{},
lse_dram_window,
identity{},
identity{},
kargs.p_compute_element_func,
kargs.o_acc_element_func,
mask,
kargs.scale,
smem_ptr);
}();
// O DRAM and O DRAM window

View File

@@ -9,6 +9,7 @@
namespace ck_tile {
// deprecated pipeline
// This pipeline is qkv all located in LDS
template <typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSDefaultPolicy>
struct BlockFmhaPipelineQRKSVSFp8