This commit is contained in:
carlushuang
2024-03-06 14:31:36 +00:00
parent f549bb5d39
commit 0e7df1999f
10 changed files with 268 additions and 68 deletions

View File

@@ -309,30 +309,83 @@ bool run(const ck_tile::ArgParser& arg_parser)
mask.type,
use_bias,
lse};
auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
batch,
nhead,
nhead_k,
shape_seqlen_q,
shape_seqlen_k,
hdim_q,
hdim_v,
max_seqlen_q,
scale,
descale_q * descale_k,
descale_v,
i_perm,
o_perm,
mask.y,
mask.x};
auto fmha_args = [&]() {
assert(nhead % nhead_k == 0);
/// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q,
/// seqlen_k] in this example, hence both the 'batch_stride_bias' &
/// 'nhead_stride_bias' are 0.
// setup stride_* arguments
const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q);
const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q);
const ck_tile::index_t stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? hdim_v : nhead_k * hdim_v;
else
return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k;
}();
const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k);
const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v);
// setup nhead_stride_* arguments
const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q);
const ck_tile::index_t nhead_stride_v = [&]() {
if(is_v_rowmajor)
return i_perm ? shape_seqlen_k * hdim_v : hdim_v;
else
return i_perm ? hdim_v * shape_seqlen_k : shape_seqlen_k;
}();
const ck_tile::index_t nhead_stride_bias =
(i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k);
const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1);
const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v);
// setup batch_stride_* arguments
const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q);
const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q);
const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k);
const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k);
const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1);
const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v);
return fmha_fwd_args{q_buf.GetDeviceBuffer(),
k_buf.GetDeviceBuffer(),
v_buf.GetDeviceBuffer(),
bias_buf.GetDeviceBuffer(),
lse_buf.GetDeviceBuffer(),
o_buf.GetDeviceBuffer(),
seqstart_q.GetDeviceBuffer(),
seqstart_k.GetDeviceBuffer(),
nullptr,
shape_seqlen_q,
shape_seqlen_k,
batch,
max_seqlen_q,
hdim_q,
hdim_v,
nhead,
nhead_k,
scale,
stride_q,
stride_k,
stride_v,
stride_bias,
stride_o,
nhead_stride_q,
nhead_stride_k,
nhead_stride_v,
nhead_stride_bias,
nhead_stride_lse,
nhead_stride_o,
batch_stride_q,
batch_stride_k,
batch_stride_v,
batch_stride_bias,
batch_stride_lse,
batch_stride_o,
mask.y,
mask.x,
descale_q * descale_k,
descale_v};
}();
float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config);

View File

@@ -80,6 +80,7 @@ struct FmhaMasks
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
};
#if 0
// internal API, don't use this directly
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(const void* q_ptr,
@@ -248,34 +249,129 @@ struct fmha_fwd_args
ck_tile::index_t mask_y;
ck_tile::index_t mask_x;
};
#endif
// Runtime arguments for the fmha forward dispatch: some fields are passed
// through into the kernel argument struct (karg), others are used only on the
// host to compute the launch grids/blocks.
struct fmha_fwd_args
{
const void* q_ptr;          // device buffer: query tensor
const void* k_ptr;          // device buffer: key tensor
const void* v_ptr;          // device buffer: value tensor
const void* bias_ptr;       // device buffer: attention bias (presumably optional — confirm with kernel)
void* lse_ptr;              // device buffer: output, presumably log-sum-exp per row — TODO confirm
void* o_ptr;                // device buffer: output tensor
const void* seqstart_q_ptr; // per-batch query sequence start offsets (used in group mode only)
const void* seqstart_k_ptr; // per-batch key sequence start offsets (used in group mode only)
const void* seqlen_k_ptr;   // optional per-batch key sequence lengths (group mode; may be nullptr)
ck_tile::index_t seqlen_q;  // query sequence length (used in batch mode)
ck_tile::index_t seqlen_k;  // key sequence length (used in batch mode)
ck_tile::index_t batch;     // batch size; used to compute the launch grid
ck_tile::index_t max_seqlen_q; // max query seqlen across batches; used for the launch grid
ck_tile::index_t hdim_q;    // head dimension of Q/K
ck_tile::index_t hdim_v;    // head dimension of V/O
ck_tile::index_t nhead_q;   // number of query heads
ck_tile::index_t nhead_k;   // number of key/value heads; nhead_q must be divisible by nhead_k
float scale;                // softmax scaling factor — presumably applied to Q*K^T; confirm
// stride_*: distance (in elements, not bytes) between consecutive rows of each tensor
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_v;
ck_tile::index_t stride_bias;
ck_tile::index_t stride_o;
// nhead_stride_*: distance between consecutive heads of each tensor
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_bias;
ck_tile::index_t nhead_stride_lse;
ck_tile::index_t nhead_stride_o;
// batch_stride_*: distance between consecutive batches (batch mode only)
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_bias;
ck_tile::index_t batch_stride_lse;
ck_tile::index_t batch_stride_o;
ck_tile::index_t mask_y; // attention mask window size, y direction — semantics defined by kernel mask type
ck_tile::index_t mask_x; // attention mask window size, x direction
float descale_qk;        // dequantization factor for Q*K (caller passes descale_q * descale_k)
float descale_sv;        // dequantization factor for S*V (caller passes descale_v)
};
/// @brief Translate the flat fmha_fwd_args into the kernel-specific kargs plus
///        the launch grid, selected at compile time by FmhaKernel::kIsGroupMode.
///
/// Group mode consumes the seqstart_q/seqstart_k (+ optional seqlen_k) pointers
/// and per-head strides only; batch mode instead consumes seqlen_q/seqlen_k and
/// additionally needs the batch_stride_* values.
///
/// NOTE(review): removed a stale, unreachable early `return` left over from the
/// previous pointer-list overload — it referenced members (nhead, i_perm,
/// o_perm) that no longer exist on fmha_fwd_args and would not compile.
template <typename FmhaKernel>
auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args)
{
    // Q heads must be an integer multiple of KV heads (MQA/GQA head grouping).
    assert(args.nhead_q % args.nhead_k == 0);
    auto kargs = [&] {
        // create group mode kernel arguments
        if constexpr(FmhaKernel::kIsGroupMode)
        {
            return FmhaKernel::MakeKargs(args.q_ptr,
                                         args.k_ptr,
                                         args.v_ptr,
                                         args.bias_ptr,
                                         args.lse_ptr,
                                         args.o_ptr,
                                         args.seqstart_q_ptr,
                                         args.seqstart_k_ptr,
                                         args.seqlen_k_ptr,
                                         args.hdim_q,
                                         args.hdim_v,
                                         args.nhead_q / args.nhead_k, // KV-head group size
                                         args.scale,
                                         args.stride_q,
                                         args.stride_k,
                                         args.stride_v,
                                         args.stride_bias,
                                         args.stride_o,
                                         args.nhead_stride_q,
                                         args.nhead_stride_k,
                                         args.nhead_stride_v,
                                         args.nhead_stride_bias,
                                         args.nhead_stride_lse,
                                         args.nhead_stride_o,
                                         args.mask_y,
                                         args.mask_x,
                                         args.descale_qk,
                                         args.descale_sv);
        }
        else
        { // create batch mode kernel arguments
            return FmhaKernel::MakeKargs(args.q_ptr,
                                         args.k_ptr,
                                         args.v_ptr,
                                         args.bias_ptr,
                                         args.lse_ptr,
                                         args.o_ptr,
                                         args.seqlen_q,
                                         args.seqlen_k,
                                         args.hdim_q,
                                         args.hdim_v,
                                         args.nhead_q / args.nhead_k, // KV-head group size
                                         args.scale,
                                         args.stride_q,
                                         args.stride_k,
                                         args.stride_v,
                                         args.stride_bias,
                                         args.stride_o,
                                         args.nhead_stride_q,
                                         args.nhead_stride_k,
                                         args.nhead_stride_v,
                                         args.nhead_stride_bias,
                                         args.nhead_stride_lse,
                                         args.nhead_stride_o,
                                         args.batch_stride_q,
                                         args.batch_stride_k,
                                         args.batch_stride_v,
                                         args.batch_stride_bias,
                                         args.batch_stride_lse,
                                         args.batch_stride_o,
                                         args.mask_y,
                                         args.mask_x,
                                         args.descale_qk,
                                         args.descale_sv);
        }
    }();

    dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v);
    return ck_tile::make_tuple(kargs, grids);
}
// this is used to pattern-match the internal kernel implementation, not to instantiate kernels

View File

@@ -8,6 +8,7 @@ from pathlib import Path
from typing import List, Optional
from dataclasses import dataclass
import copy
import fnmatch
DTYPE_MAP = {
"fp16": "ck_tile::half_t",
@@ -402,7 +403,7 @@ def get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[
else:
return None
def get_blobs() -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
def get_blobs(kernel_filter : Optional[str]) -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
# support this in future
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
@@ -443,6 +444,9 @@ def get_blobs() -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
hdim = int(hdim_str)
for pipeline in get_pipelines(dtype, hdim):
k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline)
if kernel_filter != None:
if not fnmatch.fnmatch(k.name, kernel_filter):
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
@@ -454,24 +458,24 @@ def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
def write_blobs(output_dir: Optional[str]) -> None:
def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str]) -> None:
if output_dir is None:
output_dir = Path(__file__).parent
else:
output_dir = Path(output_dir) / GEN_DIR
output_dir.mkdir(parents=True, exist_ok=True)
api_pool, kernels = get_blobs()
api_pool, kernels = get_blobs(kernel_filter)
for kernel in kernels:
write_single_kernel(kernel, output_dir)
write_api(api_pool, output_dir)
# list all the files that will be generated
def list_blobs(output_file: Optional[str]) -> None:
def list_blobs(output_file : Optional[str], kernel_filter : Optional[str]) -> None:
assert output_file is not None
file_path = Path(output_file)
with file_path.open('a') as f:
_, kernels = get_blobs()
_, kernels = get_blobs(kernel_filter)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
@@ -493,8 +497,15 @@ if __name__ == "__main__":
required=False,
help="list all the kernels to a file"
)
# TODO: when a filter is used, the same value must be passed to both the write_blobs (output_dir) and list_blobs invocations
parser.add_argument(
"-f",
"--filter",
required=False,
help="filter out kernels that need to generate, using fnmatch module"
)
args = parser.parse_args()
if args.list_blobs is not None:
list_blobs(args.list_blobs)
list_blobs(args.list_blobs, args.filter)
else:
write_blobs(args.output_dir)
write_blobs(args.output_dir, args.filter)

View File

@@ -48,7 +48,7 @@
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0
#define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1
#ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY
#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE
#endif
#ifndef CK_TILE_USE_LAUNCH_BOUNDS

View File

@@ -87,11 +87,29 @@ CK_TILE_HOST_DEVICE constexpr T max(T x)
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T max(T x, T y)
CK_TILE_HOST constexpr T max(T x, T y)
{
return x > y ? x : y;
}
/// @brief Device-side generic max of two values; returns y when unordered,
/// matching the ternary `x > y ? x : y` form.
template <typename T>
CK_TILE_DEVICE constexpr T max(T x, T y)
{
    if(y < x)
    {
        return x;
    }
    return y;
}
// Device-side float specialization: use the builtin so the compiler can lower
// to hardware fmax.
template <>
CK_TILE_DEVICE constexpr float max(float x, float y)
{
return __builtin_fmaxf(x, y); // can result in v_max3_f32 — presumably when maxes chain; confirm codegen
}
// Device-side double specialization via the builtin.
template <>
CK_TILE_DEVICE constexpr double max(double x, double y)
{
return __builtin_fmax(x, y); // maybe still lowers to v_max3_f32 — unverified
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t max(number<X>, index_t y)
{
@@ -118,11 +136,29 @@ CK_TILE_HOST_DEVICE constexpr T min(T x)
}
template <typename T>
CK_TILE_HOST_DEVICE constexpr T min(T x, T y)
CK_TILE_HOST constexpr T min(T x, T y)
{
return x < y ? x : y;
}
/// @brief Device-side generic min of two values; returns y when unordered,
/// matching the ternary `x < y ? x : y` form.
template <typename T>
CK_TILE_DEVICE constexpr T min(T x, T y)
{
    if(y > x)
    {
        return x;
    }
    return y;
}
// Device-side float specialization: use the builtin so the compiler can lower
// to hardware fmin (mirrors the max/__builtin_fmaxf specialization above).
template <>
CK_TILE_DEVICE constexpr float min(float x, float y)
{
return __builtin_fminf(x, y);
}
// Device-side double specialization via the builtin.
template <>
CK_TILE_DEVICE constexpr double min(double x, double y)
{
return __builtin_fmin(x, y);
}
template <index_t X>
CK_TILE_HOST_DEVICE constexpr index_t min(number<X>, index_t y)
{

View File

@@ -60,7 +60,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
}();
//
constexpr index_t NDimY = InTensor::get_tile_distribution().GetNumOfDimensionY();
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());

View File

@@ -78,9 +78,9 @@ struct tile_distribution
Ys2DDescriptor ys_to_d_;
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; }
CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionY() { return NDimY; }
CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionP() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionR() { return NDimR; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; }
CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; }
CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
{

View File

@@ -36,8 +36,8 @@ struct tile_window_with_static_distribution
static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
static constexpr index_t NDimP = TileDstr::GetNumOfDimensionP();
static constexpr index_t NDimY = TileDstr::GetNumOfDimensionY();
static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
@@ -265,7 +265,7 @@ struct tile_window_with_static_distribution
window_adaptor_vector_lengths, window_adaptor_vector_strides);
// [y0, y1, ...]
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::GetNumOfDimensionP(),
constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
NDimWindowAdaptorTop,
1>::type{};

View File

@@ -38,7 +38,11 @@ struct Default2DEpilogue
// TODO: this is ugly
if constexpr(kPadM || kPadN)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
// o_dram_window_tmp.foo();
// ODataType{}.foo();
// o_acc_tile.foo();
auto x = cast_tile<ODataType>(o_acc_tile);
store_tile_raw(o_dram_window_tmp, x);
buffer_store_fence();
}
else

View File

@@ -17,8 +17,8 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimP = Dstr::GetNumOfDimensionP();
constexpr index_t NDimR = Dstr::GetNumOfDimensionR();
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;