mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Generate the instance for FA required
This commit is contained in:
@@ -429,7 +429,7 @@ class FmhaFwdKernel:
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
@@ -558,6 +558,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[Fm
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
@@ -652,7 +659,7 @@ using fmha_bwd_dv_epilogue_{F_idx} =
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
|
||||
false, false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<ck_tile::FmhaBwdTilePartitioner<fmha_bwd_shape_{F_idx}>,
|
||||
fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
@@ -816,7 +823,7 @@ class FmhaBwdApiPool:
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype],
|
||||
F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad])
|
||||
|
||||
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_BWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
@@ -863,7 +870,7 @@ class FmhaBwdDQDKDVTileSize:
|
||||
@dataclass
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
direction : str
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaBwdDQDKDVTileSize
|
||||
@@ -982,7 +989,7 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction : str, dtype : str)
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
|
||||
def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaBwdApiPool, List[FmhaBwdDQDKDVKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad
|
||||
# support this in future
|
||||
gen = list()
|
||||
@@ -1001,12 +1008,17 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[Fm
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile,
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad,
|
||||
F_bias=bias, F_dbias=dbias, F_dropout=dropout, F_mask=mask, F_mode=mode,
|
||||
F_pipeline=ppl, mask_impl=mask_impl)
|
||||
if kernel_filter != None:
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
if not cond:
|
||||
continue
|
||||
api_pool.register_dq_dk_dv_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
@@ -1031,7 +1043,7 @@ using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDot
|
||||
using fmha_bwd_dot_do_o_{F_idx} = typename ck_tile::BlockFmhaBwdOGradDotO<
|
||||
fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_kernel_{F_idx} =
|
||||
using fmha_bwd_dot_do_o_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdOGradDotOKernel<ck_tile::FmhaBwdOGradDotOTilePartitioner</* BlockSize = */ 256>,
|
||||
fmha_bwd_dot_do_o_{F_idx}>;
|
||||
|
||||
@@ -1072,7 +1084,7 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
|
||||
@dataclass
|
||||
class FmhaBwdOGradDotOKernel:
|
||||
direction : str
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_spad : str # true/false
|
||||
@@ -1163,7 +1175,7 @@ def write_blobs(output_dir: Optional[str], direction: str, kernel_filter : Optio
|
||||
kernels = get_bwd_dot_do_o_blobs()
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dot_do_o_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, mask_impl)
|
||||
api_pool, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_bwd_dq_dk_dv_kernel(kernel, output_dir)
|
||||
write_bwd_api(api_pool, output_dir)
|
||||
@@ -1182,7 +1194,7 @@ def list_blobs(output_file : Optional[str], direction : str, kernel_filter : Opt
|
||||
kernels = get_bwd_dot_do_o_blobs()
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, mask_impl)
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
|
||||
@@ -1234,11 +1246,12 @@ if __name__ == "__main__":
|
||||
default=0,
|
||||
required=False,
|
||||
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
|
||||
" 1: generate more instance to cover all hdim"
|
||||
" 1: generate more instance to cover all hdim\n" + \
|
||||
" 2: Only generate instance for Flash attention integration"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, args.direction, args.filter, args.receipt, mask_impl=args.mask)
|
||||
list_blobs(args.list_blobs, args.direction, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
else:
|
||||
write_blobs(args.output_dir, args.direction, args.filter, args.receipt, mask_impl=args.mask)
|
||||
write_blobs(args.output_dir, args.direction, args.filter, int(args.receipt), mask_impl=args.mask)
|
||||
|
||||
Reference in New Issue
Block a user