diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index f54049cfcc..5a6afe36f6 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -309,30 +309,83 @@ bool run(const ck_tile::ArgParser& arg_parser) mask.type, use_bias, lse}; - auto fmha_args = fmha_fwd_args{q_buf.GetDeviceBuffer(), - k_buf.GetDeviceBuffer(), - v_buf.GetDeviceBuffer(), - bias_buf.GetDeviceBuffer(), - lse_buf.GetDeviceBuffer(), - o_buf.GetDeviceBuffer(), - seqstart_q.GetDeviceBuffer(), - seqstart_k.GetDeviceBuffer(), - nullptr, - batch, - nhead, - nhead_k, - shape_seqlen_q, - shape_seqlen_k, - hdim_q, - hdim_v, - max_seqlen_q, - scale, - descale_q * descale_k, - descale_v, - i_perm, - o_perm, - mask.y, - mask.x}; + auto fmha_args = [&]() { + assert(nhead % nhead_k == 0); + /// NOTE: we broadcast bias from [1, 1, seqlen_q, seqlen_k] to [batch, nhead, seqlen_q, + /// seqlen_k] in this example, hence both the 'batch_stride_bias' & + /// 'nhead_stride_bias' are 0. + // setup stride_* arguments + const ck_tile::index_t stride_q = (i_perm ? hdim_q : nhead * hdim_q); + const ck_tile::index_t stride_k = (i_perm ? hdim_q : nhead_k * hdim_q); + const ck_tile::index_t stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? hdim_v : nhead_k * hdim_v; + else + return i_perm ? shape_seqlen_k : nhead_k * shape_seqlen_k; + }(); + const ck_tile::index_t stride_bias = (i_perm ? shape_seqlen_k : 1 * shape_seqlen_k); + const ck_tile::index_t stride_o = (o_perm ? hdim_v : nhead * hdim_v); + // setup nhead_stride_* arguments + const ck_tile::index_t nhead_stride_q = (i_perm ? shape_seqlen_q * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_k = (i_perm ? shape_seqlen_k * hdim_q : hdim_q); + const ck_tile::index_t nhead_stride_v = [&]() { + if(is_v_rowmajor) + return i_perm ? shape_seqlen_k * hdim_v : hdim_v; + else + return i_perm ? 
hdim_v * shape_seqlen_k : shape_seqlen_k; + }(); + const ck_tile::index_t nhead_stride_bias = + (i_perm ? 0 * shape_seqlen_q * shape_seqlen_k : 0 * shape_seqlen_k); + const ck_tile::index_t nhead_stride_lse = (shape_seqlen_q * 1); + const ck_tile::index_t nhead_stride_o = (o_perm ? shape_seqlen_q * hdim_v : hdim_v); + // setup batch_stride_* arguments + const ck_tile::index_t batch_stride_q = (nhead * shape_seqlen_q * hdim_q); + const ck_tile::index_t batch_stride_k = (nhead_k * shape_seqlen_k * hdim_q); + const ck_tile::index_t batch_stride_v = (nhead_k * hdim_v * shape_seqlen_k); + const ck_tile::index_t batch_stride_bias = (0 * nhead * shape_seqlen_q * shape_seqlen_k); + const ck_tile::index_t batch_stride_lse = (nhead * shape_seqlen_q * 1); + const ck_tile::index_t batch_stride_o = (nhead * shape_seqlen_q * hdim_v); + + return fmha_fwd_args{q_buf.GetDeviceBuffer(), + k_buf.GetDeviceBuffer(), + v_buf.GetDeviceBuffer(), + bias_buf.GetDeviceBuffer(), + lse_buf.GetDeviceBuffer(), + o_buf.GetDeviceBuffer(), + seqstart_q.GetDeviceBuffer(), + seqstart_k.GetDeviceBuffer(), + nullptr, + shape_seqlen_q, + shape_seqlen_k, + batch, + max_seqlen_q, + hdim_q, + hdim_v, + nhead, + nhead_k, + scale, + stride_q, + stride_k, + stride_v, + stride_bias, + stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_lse, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_lse, + batch_stride_o, + mask.y, + mask.x, + descale_q * descale_k, + descale_v}; + }(); float ave_time = fmha_fwd(fmha_traits, fmha_args, stream_config); diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 49846a322d..9293201cd2 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -80,6 +80,7 @@ struct FmhaMasks using CausalMask = ck_tile::GenericAttentionMask; }; +#if 0 // internal API, don't use this directly template auto 
fmha_fwd_create_kargs_and_grids(const void* q_ptr, @@ -248,34 +249,129 @@ struct fmha_fwd_args ck_tile::index_t mask_y; ck_tile::index_t mask_x; }; +#endif + +// runtime args, some will passed to karg, some will used to compute grids/blocks +struct fmha_fwd_args +{ + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + void* lse_ptr; + void* o_ptr; + const void* seqstart_q_ptr; + const void* seqstart_k_ptr; + const void* seqlen_k_ptr; + ck_tile::index_t seqlen_q; + ck_tile::index_t seqlen_k; + ck_tile::index_t batch; + ck_tile::index_t max_seqlen_q; + ck_tile::index_t hdim_q; + ck_tile::index_t hdim_v; + ck_tile::index_t nhead_q; + ck_tile::index_t nhead_k; + float scale; + ck_tile::index_t stride_q; + ck_tile::index_t stride_k; + ck_tile::index_t stride_v; + ck_tile::index_t stride_bias; + ck_tile::index_t stride_o; + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_lse; + ck_tile::index_t nhead_stride_o; + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_lse; + ck_tile::index_t batch_stride_o; + ck_tile::index_t mask_y; + ck_tile::index_t mask_x; + float descale_qk; + float descale_sv; +}; template auto fmha_fwd_create_kargs_and_grids(fmha_fwd_args args) { - return fmha_fwd_create_kargs_and_grids(args.q_ptr, - args.k_ptr, - args.v_ptr, - args.bias_ptr, - args.lse_ptr, - args.o_ptr, - args.seqstart_q_ptr, - args.seqstart_k_ptr, - args.seqlen_k_ptr, - args.batch, - args.nhead, - args.nhead_k, - args.seqlen_q, - args.seqlen_k, - args.hdim_q, - args.hdim_v, - args.max_seqlen_q, - args.scale, - args.descale_qk, - args.descale_sv, - args.i_perm, - args.o_perm, - args.mask_y, - args.mask_x); + assert(args.nhead_q % args.nhead_k == 0); + auto kargs = [&] { + // create group mode kernel 
arguments + if constexpr(FmhaKernel::kIsGroupMode) + { + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqstart_q_ptr, + args.seqstart_k_ptr, + args.seqlen_k_ptr, + args.hdim_q, + args.hdim_v, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.mask_y, + args.mask_x, + args.descale_qk, + args.descale_sv); + } + else + { // create batch mode kernel arguments + return FmhaKernel::MakeKargs(args.q_ptr, + args.k_ptr, + args.v_ptr, + args.bias_ptr, + args.lse_ptr, + args.o_ptr, + args.seqlen_q, + args.seqlen_k, + args.hdim_q, + args.hdim_v, + args.nhead_q / args.nhead_k, + args.scale, + args.stride_q, + args.stride_k, + args.stride_v, + args.stride_bias, + args.stride_o, + args.nhead_stride_q, + args.nhead_stride_k, + args.nhead_stride_v, + args.nhead_stride_bias, + args.nhead_stride_lse, + args.nhead_stride_o, + args.batch_stride_q, + args.batch_stride_k, + args.batch_stride_v, + args.batch_stride_bias, + args.batch_stride_lse, + args.batch_stride_o, + args.mask_y, + args.mask_x, + args.descale_qk, + args.descale_sv); + } + }(); + + dim3 grids = FmhaKernel::GridSize(args.batch, args.nhead_q, args.max_seqlen_q, args.hdim_v); + return ck_tile::make_tuple(kargs, grids); } // this is used to pattern-match internl kernel implementation, not to instantiate kernel diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 0ffc1fcb8b..f2b7a61c17 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -8,6 +8,7 @@ from pathlib import Path from typing import List, Optional from dataclasses import dataclass import copy +import fnmatch DTYPE_MAP = { "fp16": "ck_tile::half_t", @@ -402,7 +403,7 @@ def 
get_fmha_fwd_tile_dict_from_dtype(direction : str, dtype : str) -> Optional[ else: return None -def get_blobs() -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: +def get_blobs(kernel_filter : Optional[str]) -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad # support this in future def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: @@ -443,6 +444,9 @@ def get_blobs() -> tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: hdim = int(hdim_str) for pipeline in get_pipelines(dtype, hdim): k = FmhaFwdKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_mode=mode, F_tile=tile, F_pipeline=pipeline) + if kernel_filter != None: + if not fnmatch.fnmatch(k.name, kernel_filter): + continue api_pool.register_traits(k.api_trait()) gen.append(k) @@ -454,24 +458,24 @@ def write_single_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None: def write_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None: (autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api) -def write_blobs(output_dir: Optional[str]) -> None: +def write_blobs(output_dir : Optional[str], kernel_filter : Optional[str]) -> None: if output_dir is None: output_dir = Path(__file__).parent else: output_dir = Path(output_dir) / GEN_DIR output_dir.mkdir(parents=True, exist_ok=True) - api_pool, kernels = get_blobs() + api_pool, kernels = get_blobs(kernel_filter) for kernel in kernels: write_single_kernel(kernel, output_dir) write_api(api_pool, output_dir) # list all the files that will be generated -def list_blobs(output_file: Optional[str]) -> None: +def list_blobs(output_file : Optional[str], kernel_filter : Optional[str]) -> None: assert output_file is not None file_path = Path(output_file) with file_path.open('a') as f: - _, kernels = get_blobs() + _, kernels = get_blobs(kernel_filter) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") f.write(str(file_path.parent / GEN_DIR / 
FMHA_FWD_API_FILENAME) + "\n") @@ -493,8 +497,15 @@ if __name__ == "__main__": required=False, help="list all the kernels to a file" ) + # TODO: if using filter, must apply same value to output_dir and list_blobs + parser.add_argument( + "-f", + "--filter", + required=False, + help="filter out kernels that need to generate, using fnmatch module" + ) args = parser.parse_args() if args.list_blobs is not None: - list_blobs(args.list_blobs) + list_blobs(args.list_blobs, args.filter) else: - write_blobs(args.output_dir) + write_blobs(args.output_dir, args.filter) diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 3a318ec091..a2362fb46b 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -48,7 +48,7 @@ #define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY 0 #define CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE 1 #ifndef CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT -#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_ARRAY +#define CK_TILE_STATICALLY_INDEXED_ARRAY_DEFAULT CK_TILE_STATICALLY_INDEXED_ARRAY_USE_TUPLE #endif #ifndef CK_TILE_USE_LAUNCH_BOUNDS diff --git a/include/ck_tile/core/numeric/math.hpp b/include/ck_tile/core/numeric/math.hpp index 90a8084b85..0c67a640af 100644 --- a/include/ck_tile/core/numeric/math.hpp +++ b/include/ck_tile/core/numeric/math.hpp @@ -87,11 +87,29 @@ CK_TILE_HOST_DEVICE constexpr T max(T x) } template -CK_TILE_HOST_DEVICE constexpr T max(T x, T y) +CK_TILE_HOST constexpr T max(T x, T y) { return x > y ? x : y; } +template +CK_TILE_DEVICE constexpr T max(T x, T y) +{ + return x > y ? 
x : y; +} + +template <> +CK_TILE_DEVICE constexpr float max(float x, float y) +{ + return __builtin_fmaxf(x, y); // can result in v_max3_f32 +} + +template <> +CK_TILE_DEVICE constexpr double max(double x, double y) +{ + return __builtin_fmax(x, y); // maybe still v_max3_f32 +} + template CK_TILE_HOST_DEVICE constexpr index_t max(number, index_t y) { @@ -118,11 +136,29 @@ CK_TILE_HOST_DEVICE constexpr T min(T x) } template -CK_TILE_HOST_DEVICE constexpr T min(T x, T y) +CK_TILE_HOST constexpr T min(T x, T y) { return x < y ? x : y; } +template +CK_TILE_DEVICE constexpr T min(T x, T y) +{ + return x < y ? x : y; +} + +template <> +CK_TILE_DEVICE constexpr float min(float x, float y) +{ + return __builtin_fminf(x, y); +} + +template <> +CK_TILE_DEVICE constexpr double min(double x, double y) +{ + return __builtin_fmin(x, y); +} + template CK_TILE_HOST_DEVICE constexpr index_t min(number, index_t y) { diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index e1bd9c4d19..a756679bd9 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -60,7 +60,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT }(); // - constexpr index_t NDimY = InTensor::get_tile_distribution().GetNumOfDimensionY(); + constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 81fa864340..4ee6ef72f5 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -78,9 +78,9 @@ struct tile_distribution Ys2DDescriptor ys_to_d_; CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_x() { return NDimX; } - CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionY() { return 
NDimY; } - CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionP() { return NDimP; } - CK_TILE_HOST_DEVICE static constexpr index_t GetNumOfDimensionR() { return NDimR; } + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_y() { return NDimY; } + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_p() { return NDimP; } + CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension_r() { return NDimR; } CK_TILE_HOST_DEVICE static constexpr auto get_lengths() { diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index be4c67dbce..643f6d77ef 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -36,8 +36,8 @@ struct tile_window_with_static_distribution static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension(); static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension(); - static constexpr index_t NDimP = TileDstr::GetNumOfDimensionP(); - static constexpr index_t NDimY = TileDstr::GetNumOfDimensionY(); + static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p(); + static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y(); static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -265,7 +265,7 @@ struct tile_window_with_static_distribution window_adaptor_vector_lengths, window_adaptor_vector_strides); // [y0, y1, ...] 
- constexpr auto y_dims = typename arithmetic_sequence_gen::type{}; diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 5dc49c3b0e..2bfbb8b38f 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -38,7 +38,11 @@ struct Default2DEpilogue // TODO: this is ugly if constexpr(kPadM || kPadN) { - store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); + // o_dram_window_tmp.foo(); + // ODataType{}.foo(); + // o_acc_tile.foo(); + auto x = cast_tile(o_acc_tile); + store_tile_raw(o_dram_window_tmp, x); buffer_store_fence(); } else diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 176b870ec9..682d60d872 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -17,8 +17,8 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, using DstrEncode = typename Dstr::DstrEncode; using DstrEncodeDetail = typename DstrEncode::detail; - constexpr index_t NDimP = Dstr::GetNumOfDimensionP(); - constexpr index_t NDimR = Dstr::GetNumOfDimensionR(); + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); constexpr index_t idim_p_lane = NDimP - 1;