Remove f8 pipeline, we should share the same pipeline even in f8

This commit is contained in:
rocking
2024-04-08 09:56:23 +00:00
parent f7d81364f3
commit 5c3fdeb0b8
5 changed files with 30 additions and 94 deletions

View File

@@ -49,9 +49,6 @@ auto create_args(int argc, char* argv[])
.insert("d", "128", "head dim for q, k")
.insert("d_v", "0", "head dim for v, 0 means equal to d")
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
.insert("descale_q", "1", "scale factor for fp8 quantization")
.insert("descale_k", "1", "scale factor for fp8 quantization")
.insert("descale_v", "1", "scale factor for fp8 quantization")
.insert("iperm",
"1",
"permute input\n"
@@ -140,10 +137,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
if(scale == .0f)
scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
float descale_q = arg_parser.get_float("descale_q");
float descale_k = arg_parser.get_float("descale_k");
float descale_v = arg_parser.get_float("descale_v");
std::string vlayout = arg_parser.get_str("vlayout");
bool use_bias = arg_parser.get_bool("bias");
bool lse = arg_parser.get_bool("lse");
@@ -384,9 +377,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.y,
mask.x,
ck_tile::identity{},
ck_tile::identity{},
descale_q * descale_k,
descale_v};
ck_tile::identity{}};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);

View File

@@ -125,8 +125,6 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
ck_tile::index_t hdim_v,
ck_tile::index_t max_seqlen_q,
float scale,
float descale_qk,
float descale_sv,
bool i_perm,
bool o_perm,
ck_tile::index_t mask_y,
@@ -199,9 +197,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
nhead_stride_lse,
nhead_stride_o,
mask_y,
mask_x,
descale_qk,
descale_sv);
mask_x);
}
else
{ // create batch mode kernel arguments
@@ -235,9 +231,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
batch_stride_lse,
batch_stride_o,
mask_y,
mask_x,
descale_qk,
descale_sv);
mask_x);
}
}();
@@ -266,8 +260,6 @@ struct fmha_fwd_args
ck_tile::index_t hdim_v;
ck_tile::index_t max_seqlen_q;
float scale;
float descale_qk;
float descale_sv;
bool i_perm;
bool o_perm;
ck_tile::index_t mask_y;
@@ -324,8 +316,6 @@ struct fmha_fwd_args
// typename ElementFunctions::SAccElementFunction s_acc_element_func;
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
typename ElementFunctions::OAccElementFunction o_acc_element_func;
float descale_qk;
float descale_sv;
};
template <typename FmhaKernel, typename FmhaFwdArgs>
@@ -369,9 +359,7 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
// args.lse_element_func,
// args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
args.descale_sv);
args.o_acc_element_func);
}
else
{ // create batch mode kernel arguments
@@ -413,9 +401,7 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
// args.lse_element_func,
// args.s_acc_element_func,
args.p_compute_element_func,
args.o_acc_element_func,
args.descale_qk,
args.descale_sv);
args.o_acc_element_func);
}
}();

View File

@@ -42,7 +42,6 @@ LAYOUT_MAP = {
PIPELINE_MAP = {
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
"qr_fp8" : "ck_tile::BlockFmhaPipelineQRKSVSFp8",
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
}
@@ -212,7 +211,7 @@ class FmhaFwdApiTrait:
if self.pipeline_tag == 'qr_async':
if self.spad == 't' : return 'true' # always support
else : return 'true'
elif self.pipeline_tag in ['qr', 'qr_fp8']:
elif self.pipeline_tag in ['qr']:
if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0'
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@@ -228,7 +227,7 @@ class FmhaFwdApiTrait:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qr_fp8']:
elif self.pipeline_tag in ['qr']:
if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0'
else : return f'a.hdim_q % {self.bk0blen} == 0'
else: assert False
@@ -239,7 +238,7 @@ class FmhaFwdApiTrait:
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
else : assert False
elif self.pipeline_tag in ['qr', 'qr_fp8']:
elif self.pipeline_tag in ['qr']:
if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0'
else : return f'a.hdim_v % {self.bk0blen} == 0'
else: assert False
@@ -450,7 +449,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
elif dtype in ['fp8', 'bf8']:
# no need lse kernels
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
else:
assert False
return pipelines