mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Remove the dedicated fp8 pipeline; fp8 should share the same pipeline as the other data types
This commit is contained in:
@@ -49,9 +49,6 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("d", "128", "head dim for q, k")
|
||||
.insert("d_v", "0", "head dim for v, 0 means equal to d")
|
||||
.insert("scale", "0", "scale factor. 0 means equal to 1/sqrt(hdim)")
|
||||
.insert("descale_q", "1", "scale factor for fp8 quantization")
|
||||
.insert("descale_k", "1", "scale factor for fp8 quantization")
|
||||
.insert("descale_v", "1", "scale factor for fp8 quantization")
|
||||
.insert("iperm",
|
||||
"1",
|
||||
"permute input\n"
|
||||
@@ -140,10 +137,6 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
if(scale == .0f)
|
||||
scale = 1.0 / ck_tile::sqrt(static_cast<float>(hdim_q)); // TODO: q ? v ?
|
||||
|
||||
float descale_q = arg_parser.get_float("descale_q");
|
||||
float descale_k = arg_parser.get_float("descale_k");
|
||||
float descale_v = arg_parser.get_float("descale_v");
|
||||
|
||||
std::string vlayout = arg_parser.get_str("vlayout");
|
||||
bool use_bias = arg_parser.get_bool("bias");
|
||||
bool lse = arg_parser.get_bool("lse");
|
||||
@@ -384,9 +377,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
mask.y,
|
||||
mask.x,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
descale_q * descale_k,
|
||||
descale_v};
|
||||
ck_tile::identity{}};
|
||||
}();
|
||||
|
||||
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);
|
||||
|
||||
@@ -125,8 +125,6 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
float scale,
|
||||
float descale_qk,
|
||||
float descale_sv,
|
||||
bool i_perm,
|
||||
bool o_perm,
|
||||
ck_tile::index_t mask_y,
|
||||
@@ -199,9 +197,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
|
||||
nhead_stride_lse,
|
||||
nhead_stride_o,
|
||||
mask_y,
|
||||
mask_x,
|
||||
descale_qk,
|
||||
descale_sv);
|
||||
mask_x);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -235,9 +231,7 @@ auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
|
||||
batch_stride_lse,
|
||||
batch_stride_o,
|
||||
mask_y,
|
||||
mask_x,
|
||||
descale_qk,
|
||||
descale_sv);
|
||||
mask_x);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -266,8 +260,6 @@ struct fmha_fwd_args
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
float scale;
|
||||
float descale_qk;
|
||||
float descale_sv;
|
||||
bool i_perm;
|
||||
bool o_perm;
|
||||
ck_tile::index_t mask_y;
|
||||
@@ -324,8 +316,6 @@ struct fmha_fwd_args
|
||||
// typename ElementFunctions::SAccElementFunction s_acc_element_func;
|
||||
typename ElementFunctions::PComputeElementFunction p_compute_element_func;
|
||||
typename ElementFunctions::OAccElementFunction o_acc_element_func;
|
||||
float descale_qk;
|
||||
float descale_sv;
|
||||
};
|
||||
|
||||
template <typename FmhaKernel, typename FmhaFwdArgs>
|
||||
@@ -369,9 +359,7 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
|
||||
// args.lse_element_func,
|
||||
// args.s_acc_element_func,
|
||||
args.p_compute_element_func,
|
||||
args.o_acc_element_func,
|
||||
args.descale_qk,
|
||||
args.descale_sv);
|
||||
args.o_acc_element_func);
|
||||
}
|
||||
else
|
||||
{ // create batch mode kernel arguments
|
||||
@@ -413,9 +401,7 @@ auto fmha_fwd_create_kargs_and_grids(FmhaFwdArgs args)
|
||||
// args.lse_element_func,
|
||||
// args.s_acc_element_func,
|
||||
args.p_compute_element_func,
|
||||
args.o_acc_element_func,
|
||||
args.descale_qk,
|
||||
args.descale_sv);
|
||||
args.o_acc_element_func);
|
||||
}
|
||||
}();
|
||||
|
||||
|
||||
@@ -42,7 +42,6 @@ LAYOUT_MAP = {
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_fp8" : "ck_tile::BlockFmhaPipelineQRKSVSFp8",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
@@ -212,7 +211,7 @@ class FmhaFwdApiTrait:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.spad == 't' : return f'a.seqlen_q % {self.bm0} != 0'
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
@@ -228,7 +227,7 @@ class FmhaFwdApiTrait:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.dpad == 't': return f'a.hdim_q % {self.bk0blen} != 0'
|
||||
else : return f'a.hdim_q % {self.bk0blen} == 0'
|
||||
else: assert False
|
||||
@@ -239,7 +238,7 @@ class FmhaFwdApiTrait:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.dvpad == 't': return f'a.hdim_v % {self.bk0blen} != 0'
|
||||
else : return f'a.hdim_v % {self.bk0blen} == 0'
|
||||
else: assert False
|
||||
@@ -450,7 +449,7 @@ def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaF
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse kernels
|
||||
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
Reference in New Issue
Block a user