Merge branch 'ck_tile/refactor' into ck_tile/elementwise

This commit is contained in:
rocking
2024-04-01 16:07:27 +08:00
committed by GitHub
10 changed files with 163 additions and 9 deletions

View File

@@ -15,7 +15,10 @@ add_custom_command(
)
set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
add_example_executable(${EXAMPLE_FMHA_FWD} fmha_fwd.cpp)
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding tile_example ${EXAMPLE_NAME}")
add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL fmha_fwd.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
target_sources(${EXAMPLE_FMHA_FWD} PRIVATE ${FMHA_FWD_GEN_BLOBS})

View File

@@ -6,7 +6,7 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck_tile-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_fmha_fwd -j
```
This will result in an executable `build/bin/tile_example_fmha_fwd`

View File

@@ -5,7 +5,7 @@
import argparse
import itertools
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
from dataclasses import dataclass
import copy
import fnmatch
@@ -414,7 +414,7 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
else:
return None
def get_blobs(kernel_filter : Optional[str]) -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_blobs(kernel_filter : Optional[str]) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
@@ -439,7 +439,7 @@ def get_blobs(kernel_filter : Optional[str]) -> tuple[FmhaFwdApiPool, List[FmhaF
for mask, bias in itertools.product(MASK_MAP.keys(), ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr_fp8', 'col', 'f', 'f', 'f', 'f', bias, 'f', mask))
else:
assert Fasle
assert False
return pipelines
gen = list()

View File

@@ -9,6 +9,7 @@
#include <tuple>
#include <utility>
#include <vector>
#include <functional>
#include "ck_tile/core/container/span.hpp"