mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
Merge branch 'ck_tile/refactor' into ck_tile/elementwise
This commit is contained in:
@@ -15,7 +15,10 @@ add_custom_command(
|
||||
)
|
||||
|
||||
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp)
|
||||
# not using add_example_executable() to add this target, since we don't want this to have
|
||||
# to be included in "make all/install/check"
|
||||
message("adding tile_example ${EXAMPLE_NAME}")
|
||||
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
|
||||
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile
|
||||
```
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
sh ../script/cmake-ck_tile-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
|
||||
make tile_example_fmha_fwd -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_fmha_fwd`
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
import argparse
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
import copy
|
||||
import fnmatch
|
||||
@@ -414,7 +414,7 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_blobs(kernel_filter : Optional[str]) -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
@@ -439,7 +439,7 @@ def get_blobs(kernel_filter : Optional[str]) -> tuple[FmhaFwdApiPool, List[FmhaF
|
||||
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
|
||||
else:
|
||||
assert Fasle
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include <functional>
|
||||
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user