mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 01:27:40 +00:00
* chore(copyright): update copyright header for codegen directory * chore(copyright): update copyright header for example directory
43 lines
1.4 KiB
Python
43 lines
1.4 KiB
Python
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional
|
|
|
|
|
|
@dataclass(frozen=True)
class ArchTrait:
    """Immutable description of a target architecture, keyed by ``name``.

    Only ``name`` is required. Any other field left as ``None`` is derived
    from ``name`` in ``__post_init__``:

    * ``preprocessor_check`` -> ``defined(__<name>__)``
    * ``device_name_check``  -> ``device_name.compare(0, <len(name)>, "<name>") == 0``
    * ``tag``                -> ``ck_tile::<name>_t``
    * ``filename_suffix``    -> ``_<name>``

    Explicitly passed values (including empty strings) are kept as given.
    NOTE(review): the derived strings look like C++ snippets that get
    emitted into generated sources; confirm against the code that uses them.
    """

    name: str
    # Each field below defaults to None, meaning "derive it from ``name``".
    preprocessor_check: Optional[str] = None
    device_name_check: Optional[str] = None
    tag: Optional[str] = None
    filename_suffix: Optional[str] = None

    def __post_init__(self):
        """Fill every ``None`` field with its default derived from ``name``."""
        # The dataclass is frozen, so plain assignment would raise
        # FrozenInstanceError; object.__setattr__ bypasses that during init.
        if self.preprocessor_check is None:
            object.__setattr__(self, "preprocessor_check", f"defined(__{self.name}__)")
        if self.device_name_check is None:
            # Compares only the first len(name) characters, i.e. a prefix test
            # (assumes device_name is a C++ std::string in the generated code).
            object.__setattr__(
                self,
                "device_name_check",
                f'device_name.compare(0, {len(self.name)}, "{self.name}") == 0',
            )
        if self.tag is None:
            object.__setattr__(self, "tag", f"ck_tile::{self.name}_t")
        if self.filename_suffix is None:
            object.__setattr__(self, "filename_suffix", f"_{self.name}")
|
|
|
|
|
|
def get_factories_for_targets(
    targets: List[str], get_factory: Callable[[str], Any]
) -> List[Any]:
    """Resolve each target to a factory, keeping one factory per architecture.

    ``get_factory`` is called once per entry of ``targets``, in order, and
    each result is keyed by ``factory.arch.name``. When several targets map
    to the same architecture name, the factory returned last wins, but it
    keeps the position where that name was first seen.

    Returns:
        The unique factories sorted by descending length of ``arch.name``
        (longer names are treated as more specific). Factories whose names
        have equal length keep their first-seen order.
    """
    # Later duplicates overwrite the value but not the key's insertion slot.
    unique_by_arch = {
        factory.arch.name: factory for factory in map(get_factory, targets)
    }
    # Sorting on the negated length is stable, like reverse=True on the length.
    return sorted(unique_by_arch.values(), key=lambda f: -len(f.arch.name))
|