diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 27347b4476..375884d327 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -5,25 +5,29 @@ import argparse from enum import IntEnum from pathlib import Path +import pkgutil +import sys from typing import List, Optional +import codegen.ops from codegen.cmake_config import * -from codegen.ops import ( - fmha_fwd, - fmha_fwd_splitkv, - fmha_bwd -) - class HandlerId(IntEnum): LIST_BLOBS = 0 WRITE_BLOBS = 1 -handlers = { - 'fwd' : (fmha_fwd.list_blobs, fmha_fwd.write_blobs), - 'fwd_splitkv' : (fmha_fwd_splitkv.list_blobs, fmha_fwd_splitkv.write_blobs), - 'bwd' : (fmha_bwd.list_blobs, fmha_bwd.write_blobs), -} +# inspect all modules under 'codegen.ops' and register API handlers +ops = [] +for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__): + full_module_name = '%s.%s' % (codegen.ops.__name__, module_name) + if full_module_name not in sys.modules: + ops.append(importer.find_module(module_name).load_module(module_name)) +unwanted_prefix = 'fmha_' +handlers = dict( + [(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__, + (op.list_blobs, op.write_blobs)) for op in ops] +) +assert 0 < len(handlers) def write_blobs(output_dir: Optional[str], api_list : List[str], kernel_filter : Optional[str], receipt, mask_impl) -> None: if output_dir is None: