mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
fix: python 3.8 compatibility in fmha codegen (#3388)
This commit is contained in:
@@ -770,7 +770,7 @@ def create_kernel(
|
||||
|
||||
class CompatibilityRuleFactory:
|
||||
@staticmethod
|
||||
def get_rules() -> list[CompatibilityRule]:
|
||||
def get_rules() -> List[CompatibilityRule]:
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
def check_mode(problem_ctx: ProblemContext, kernel_ctx: KernelContext) -> bool:
|
||||
if problem_ctx.mode == "group":
|
||||
@@ -812,7 +812,7 @@ class CompatibilityRuleFactoryGfx9(CompatibilityRuleFactory):
|
||||
_AVAILABLE_PIPELINES = frozenset({"qr", "qr_async", "qs"})
|
||||
|
||||
@classmethod
|
||||
def get_rules(cls) -> list[CompatibilityRule]:
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = CompatibilityRuleFactory.get_rules()
|
||||
|
||||
def check_hdim_tile(
|
||||
@@ -846,7 +846,7 @@ class CompatibilityRuleFactoryGfx950(CompatibilityRuleFactoryGfx9):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_rules(cls) -> list[CompatibilityRule]:
|
||||
def get_rules(cls) -> List[CompatibilityRule]:
|
||||
rules = CompatibilityRuleFactoryGfx9.get_rules()
|
||||
|
||||
def check_tile_pipeline(
|
||||
|
||||
Reference in New Issue
Block a user