From bd9cd53885f9b086f72e804ade3e6dcf286a96fc Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 12 May 2024 22:33:22 +0000 Subject: [PATCH] now fwd/bwd can build --- example/ck_tile/01_fmha/CMakeLists.txt | 6 ++++ example/ck_tile/01_fmha/fmha_bwd.cpp | 18 +++++++----- example/ck_tile/01_fmha/fmha_bwd.hpp | 7 +++-- example/ck_tile/01_fmha/fmha_fwd.cpp | 6 ++-- example/ck_tile/01_fmha/fmha_fwd.hpp | 2 +- example/ck_tile/01_fmha/generate.py | 12 ++++---- include/ck_tile/core.hpp | 6 ++-- .../ck_tile/core/tensor/tile_distribution.hpp | 1 + include/ck_tile/host.hpp | 2 +- include/ck_tile/ops/fmha.hpp | 28 +++++++++---------- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 24 ++++++++++------ .../ops/fmha/kernel/fmha_fwd_kernel.hpp | 4 +-- ...k_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp | 13 +++++---- ...block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp | 13 +++++---- ...mha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp | 13 +++++---- ...block_fmha_bwd_pipeline_default_policy.hpp | 2 +- .../block_fmha_bwd_pipeline_problem.hpp | 2 +- include/ck_tile/ops/gemm.hpp | 5 +++- 18 files changed, 94 insertions(+), 70 deletions(-) diff --git a/example/ck_tile/01_fmha/CMakeLists.txt b/example/ck_tile/01_fmha/CMakeLists.txt index d81218c79e..e324f85ed8 100644 --- a/example/ck_tile/01_fmha/CMakeLists.txt +++ b/example/ck_tile/01_fmha/CMakeLists.txt @@ -67,3 +67,9 @@ list(APPEND EXAMPLE_FMHA_BWD_COMPILE_OPTIONS -Wno-float-equal) target_compile_options(${EXAMPLE_FMHA_FWD} PRIVATE ${EXAMPLE_FMHA_FWD_COMPILE_OPTIONS}) target_compile_options(${EXAMPLE_FMHA_BWD} PRIVATE ${EXAMPLE_FMHA_BWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/01_fmha/fmha_bwd.cpp b/example/ck_tile/01_fmha/fmha_bwd.cpp index 37a21b25ef..5a07f713e3 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.cpp +++ b/example/ck_tile/01_fmha/fmha_bwd.cpp @@ -57,7 +57,11 @@ auto create_args(int argc, char* argv[]) "permute input\n" "if true, will be b*h*s*d, else b*s*h*d") .insert("operm", "1", "permute output") - .insert("bias", "0", "add bias or not") + .insert("bias", + "n", + "n or 0, no bias\n" + "e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n" + "a(libi) or 2, alibi with 1*h. a:1, b*h") .insert("dbias", "0", "output bias gradient or not") .insert("prec", "fp16", "data type. fp16 or bf16") .insert("mask", @@ -136,12 +140,12 @@ bool run(const ck_tile::ArgParser& arg_parser) if(scale == .0f) scale = 1.0 / ck_tile::sqrt(static_cast(hdim_q)); - bool use_bias = arg_parser.get_bool("bias"); + bias_info bias = bias_info::decode(arg_parser.get_str("bias")); bool use_dbias = arg_parser.get_bool("dbias"); float p_drop = arg_parser.get_float("p_drop"); uint64_t drop_seed = arg_parser.get_uint64("drop_seed"); uint64_t drop_offset = arg_parser.get_uint64("drop_offset"); - if(use_dbias && !use_bias) + if(use_dbias && bias.type != bias_enum::elementwise_bias) { std::cerr << "dbias only exists when there is a bias" << std::endl; return false; @@ -266,7 +270,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // use bias shape = [1, 1, shape_seqlen_q, max_seqlen_k]. if use_bias=false, the bias_host // will not be used for verification at all (but will be copied to device anyway). ck_tile::HostTensor bias_host( - use_bias + bias.type == bias_enum::elementwise_bias ? get_lengths(i_perm, 1, 1, shape_seqlen_q, max_seqlen_k) : std::array{1, 1, 1, 1} /* dummy shape for simplifying code */); ck_tile::HostTensor o_host( @@ -354,7 +358,7 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k - << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << use_bias + << ", d:" << hdim_q << "/" << hdim_v << ", scale:" << scale << ", bias:" << bias << ", dbias:" << use_dbias << ", p_drop:" << p_drop << ", mask:" << mask << std::flush; @@ -363,7 +367,7 @@ bool run(const ck_tile::ArgParser& arg_parser) data_type, mode == mode_enum::group, mask.type, - use_bias, + bias.type, use_dbias, p_drop > 0.0f}; auto fmha_args = [&]() { @@ -556,7 +560,7 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::identity{}, ck_tile::scales(scale)); // s_g_m_n = scale * q_g_m_k@k_g_n_k - if(use_bias) + if(bias.type == bias_enum::elementwise_bias) { // clang-format off ck_tile::HostTensor bias_host_ref({1, real_seqlen_q, real_seqlen_k}); diff --git a/example/ck_tile/01_fmha/fmha_bwd.hpp b/example/ck_tile/01_fmha/fmha_bwd.hpp index 67817bbcd3..9d5148689f 100644 --- a/example/ck_tile/01_fmha/fmha_bwd.hpp +++ b/example/ck_tile/01_fmha/fmha_bwd.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/fmha.hpp" #include "ck_tile/ops/epilogue.hpp" #include "mask.hpp" +#include "bias.hpp" #include template @@ -291,7 +292,7 @@ template ; - static constexpr bool kHasBias = kHasBias_; + static constexpr auto BiasEnum = BiasEnum_; static constexpr bool kHasBiasGrad = kHasBiasGrad_; static constexpr bool kHasDropout = kHasDropout_; static constexpr bool kPadS = kPadS_; @@ -338,7 +339,7 @@ struct fmha_bwd_traits std::string data_type; bool is_group_mode; mask_enum mask_type; - bool has_bias; + bias_enum bias_type; bool has_dbias; bool has_dropout; // TODO: padding check is inside this api diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index d8461120f7..f5abf7e67a 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -204,7 +204,7 @@ bool run(const ck_tile::ArgParser& arg_parser) // scale_p = [max(fp8_t)/range_o] * [range_p/max(fp8_t)] * [range_v/max(fp8_t)] scale_o = range_p * range_v / range_o / dtype_max; } - + std::string vlayout = arg_parser.get_str("vlayout"); bool lse = arg_parser.get_bool("lse"); @@ -424,8 +424,8 @@ bool run(const ck_tile::ArgParser& arg_parser) std::cout << "[" << prec << "|" << mode << "|" << io_layout(i_perm, o_perm) << "] b:" << batch << ", h:" << nhead << "/" << nhead_k << ", s:" << seqlen_q << "/" << seqlen_k << ", d:" << hdim_q << "/" << hdim_v << ", scale_s:" << scale_s << ", bias:" << bias - << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant << ", mask:" << mask << ", v:" << vlayout - << std::flush; + << ", p_drop:" << p_drop << ", lse:" << lse << ", squant:" << squant + << ", mask:" << mask << ", v:" << vlayout << std::flush; auto fmha_traits = fmha_fwd_traits{hdim_q, hdim_v, diff --git a/example/ck_tile/01_fmha/fmha_fwd.hpp b/example/ck_tile/01_fmha/fmha_fwd.hpp index 0a7a2022f3..3594f61db9 100644 --- a/example/ck_tile/01_fmha/fmha_fwd.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd.hpp @@ -91,7 +91,7 @@ struct fmha_fwd_args const void* q_ptr; const void* k_ptr; const void* v_ptr; - const void* bias_ptr; // bias or alibi_slope pointer + const void* bias_ptr; // bias or alibi_slope pointer void* rand_val_ptr; void* lse_ptr; void* o_ptr; diff --git a/example/ck_tile/01_fmha/generate.py b/example/ck_tile/01_fmha/generate.py index 2a0c164fe0..da7e933c7d 100644 --- a/example/ck_tile/01_fmha/generate.py +++ b/example/ck_tile/01_fmha/generate.py @@ -681,7 +681,7 @@ FMHA_BWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v < }} """ -FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.has_bias == {F_bias}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && +FMHA_BWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && (t.has_dropout == {F_dropout}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{ using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_pipeline_enum}, {F_mask}, {F_bias}, {F_dbias}, {F_dropout}, {F_spad0}, {F_skpad}, {F_dpad}, {F_dvpad}>; using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1}, {F_dvpad}>; @@ -772,7 +772,7 @@ class FmhaBwdApiPool: if ((spad1 == "f" and trait.spad == "t") or (trait.mode == "group" and spad1 == "f")): continue inners = inners + FMHA_BWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_pipeline_enum=BWD_DQDKDV_PIPELINE_ENUM_MAP[trait.pipeline], - F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias=BOOL_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], + F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout=BOOL_MAP[trait.dropout], F_scheck=trait.scheck(spad1=spad1), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=hdim, F_dtype=DTYPE_MAP[dtype], F_spad0=BOOL_MAP[trait.spad], F_spad1=BOOL_MAP[spad1], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad]) @@ -868,7 +868,7 @@ class FmhaBwdDQDKDVKernel: F_skpad = BOOL_MAP[self.F_skpad], F_dpad = BOOL_MAP[self.F_dpad], F_dvpad = BOOL_MAP[self.F_dvpad], - F_bias = BOOL_MAP[self.F_bias], + F_bias = BIAS_MAP[self.F_bias], F_dbias = BOOL_MAP[self.F_dbias], F_dropout = BOOL_MAP[self.F_dropout], F_occupancy = self.F_tile.F_occupancy, @@ -890,7 +890,7 @@ class FmhaBwdDQDKDVKernel: mn = mask_name() n = f"fmha_{self.direction}_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name +\ f"_p{BOOL_MAP[self.F_spad][0]}{BOOL_MAP[self.F_skpad][0]}{BOOL_MAP[self.F_dpad][0]}{BOOL_MAP[self.F_dvpad][0]}" +\ - f"_b{BOOL_MAP[self.F_bias][0]}_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}" + (f'_{self.F_bias}' if self.F_bias != 'no' else '') + f"_db{BOOL_MAP[self.F_dbias][0]}_dp{BOOL_MAP[self.F_dropout][0]}" if mn != '' : n += f'{mn}' return n @@ -944,13 +944,13 @@ def get_bwd_dq_dk_dv_blobs(kernel_filter : Optional[str], mask_impl) -> Tuple[Fm d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(direction, dtype) if d == None: continue - for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): + for hdim_str, mode, mask, bias, dbias, dropout, spad, skpad, dpad, dvpad in itertools.product(d.keys(), MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"], ["t", "f"]): tile = d[hdim_str][0] ppl = d[hdim_str][1] hdim = int(hdim_str) if (mode == "group") and (spad == "f" or skpad == "f"): continue - if (bias == "f" and dbias == "t"): + if (bias == "no" and dbias == "t"): continue k = FmhaBwdDQDKDVKernel(direction=direction, F_idx=0, F_hdim=hdim, F_dtype=dtype, F_tile=tile, F_spad=spad, F_skpad=skpad, F_dpad=dpad, F_dvpad=dvpad, diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 5a175a61c8..bb490cce4a 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -7,8 +7,8 @@ #include "ck_tile/core/algorithm/coordinate_transform.hpp" #include "ck_tile/core/algorithm/space_filling_curve.hpp" #include "ck_tile/core/arch/amd_buffer_addressing.hpp" -#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/core/arch/generic_memory_space_atomic.hpp" #include "ck_tile/core/arch/utility.hpp" #include "ck_tile/core/config.hpp" #include "ck_tile/core/container/array.hpp" @@ -38,7 +38,6 @@ #include "ck_tile/core/tensor/slice_tile.hpp" #include "ck_tile/core/tensor/static_distributed_tensor.hpp" #include "ck_tile/core/tensor/store_tile.hpp" -#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/tensor/sweep_tile.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" #include "ck_tile/core/tensor/tensor_adaptor_coordinate.hpp" @@ -49,13 +48,14 @@ #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/tensor/tile_elementwise.hpp" #include "ck_tile/core/tensor/tile_window.hpp" +#include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/ignore.hpp" #include "ck_tile/core/utility/magic_div.hpp" +#include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/random.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" #include "ck_tile/core/utility/unary_element_function.hpp" -#include "ck_tile/core/utility/philox_rand.hpp" diff --git a/include/ck_tile/core/tensor/tile_distribution.hpp b/include/ck_tile/core/tensor/tile_distribution.hpp index 9fee2fd5c6..42a30232fb 100644 --- a/include/ck_tile/core/tensor/tile_distribution.hpp +++ b/include/ck_tile/core/tensor/tile_distribution.hpp @@ -9,6 +9,7 @@ #include "ck_tile/core/container/sequence.hpp" #include "ck_tile/core/container/tuple.hpp" #include "ck_tile/core/container/container_helper.hpp" +#include "ck_tile/core/container/meta_data_buffer.hpp" #include "ck_tile/core/tensor/tensor_adaptor.hpp" #include "ck_tile/core/tensor/tile_distribution_encoding.hpp" #include "ck_tile/core/utility/functional.hpp" diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index 62fce34d1a..07a51ff9b6 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -11,11 +11,11 @@ #include "ck_tile/host/host_tensor.hpp" #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/host/ranges.hpp" +#include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_batched_elementwise.hpp" #include "ck_tile/host/reference/reference_batched_gemm.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" -#include "ck_tile/host/reference/reference_batched_dropout.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index cacc769c99..5684868306 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -4,11 +4,24 @@ #pragma once #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" -#include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp" +#include "ck_tile/ops/fmha/block/block_masking.hpp" #include "ck_tile/ops/fmha/block/block_position_encoding.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" +#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp" #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp" @@ -19,19 +32,6 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp" -#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" -#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index b657646ebe..bfa1bae7f3 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include #include @@ -56,7 +57,7 @@ struct FmhaBwdDQDKDVKernel static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr bool kHasBias = FmhaPipeline::kHasBias; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; static constexpr bool kHasDropout = FmhaPipeline::kHasDropout; using FmhaMask = ck_tile::remove_cvref_t; @@ -91,7 +92,8 @@ struct FmhaBwdDQDKDVKernel _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" + "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + - ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + (kHasBias ? "_bias" : "") + + ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) + + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr::name)) + (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ); #undef _SS_ #undef _TS_ @@ -212,7 +214,9 @@ struct FmhaBwdDQDKDVKernel struct FmhaBwdBatchModeKargs : FmhaBwdCommonKargs, - std::conditional_t>, + std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -227,7 +231,9 @@ struct FmhaBwdDQDKDVKernel struct FmhaBwdGroupModeKargs : FmhaBwdCommonKargs, - std::conditional_t>, + std::conditional_t>, std::conditional_t>, std::conditional_t>, std::conditional_t> @@ -336,7 +342,7 @@ struct FmhaBwdDQDKDVKernel batch_stride_dk, batch_stride_dv}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; @@ -458,7 +464,7 @@ struct FmhaBwdDQDKDVKernel reinterpret_cast(seqstart_k_ptr), reinterpret_cast(seqlen_k_ptr)}; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { kargs.bias_ptr = bias_ptr; kargs.stride_bias = stride_bias; @@ -537,7 +543,7 @@ struct FmhaBwdDQDKDVKernel batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; batch_offset_dk = key_start * kargs.stride_dk; batch_offset_dv = key_start * kargs.stride_dv; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = query_start * kargs.stride_bias; } @@ -587,7 +593,7 @@ struct FmhaBwdDQDKDVKernel batch_offset_lsed = static_cast(i_batch) * kargs.batch_stride_lsed; batch_offset_dk = static_cast(i_batch) * kargs.batch_stride_dk; batch_offset_dv = static_cast(i_batch) * kargs.batch_stride_dv; - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; } @@ -919,7 +925,7 @@ struct FmhaBwdDQDKDVKernel constexpr auto bias_dram_window_lengths = make_tuple(number{}, number{}); const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { const BiasDataType* bias_ptr = reinterpret_cast(kargs.bias_ptr) + diff --git a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp index 0bd0741b7b..5939384977 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp @@ -34,8 +34,8 @@ struct FmhaFwdKernel using BiasDataType = ck_tile::remove_cvref_t; using RandValOutputDataType = ck_tile::remove_cvref_t; - using LSEDataType = ck_tile::remove_cvref_t; - using ODataType = ck_tile::remove_cvref_t; + using LSEDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; using SaccDataType = ck_tile::remove_cvref_t; using VLayout = ck_tile::remove_cvref_t; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp index 4b2c469ca9..93f165019b 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp @@ -58,7 +58,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kHasDropout = Problem::kHasDropout; @@ -433,13 +433,13 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR q_block_tile = load_tile(q_dram_window); // global read 1 } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -484,7 +484,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR } // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { block_sync_lds(); auto bias_shuffle_tmp = make_static_distributed_tensor( @@ -532,7 +532,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR const auto lse = load_tile(lse_dram_window); static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_lse == -numeric::infinity() ? type_convert(0.f) @@ -554,7 +555,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp index ce81b3bfd6..9fc13566d1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp @@ -58,7 +58,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kHasDropout = Problem::kHasDropout; @@ -406,13 +406,13 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR q_block_tile = load_tile(q_dram_window); // global read 1 } - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -457,7 +457,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR } // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { block_sync_lds(); auto bias_shuffle_tmp = make_static_distributed_tensor( @@ -505,7 +505,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR const auto lse = load_tile(lse_dram_window); static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_lse == -numeric::infinity() ? type_convert(0.f) @@ -527,7 +528,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSVR sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp index 5ffa7f8d50..f022fe46d1 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp @@ -58,7 +58,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; - static constexpr bool kHasBias = Problem::kHasBias; + static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad; static constexpr bool kHasDropout = Problem::kHasDropout; @@ -372,13 +372,13 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS clear_tile(st_acc); // Initialize S^T store_tile(q_lds_window, q_block_tile); // LDS write - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads } const auto bias_tile = load_tile(bias_dram_window); // load bias tile - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { __builtin_amdgcn_sched_barrier( 0); // prevent from messing up the order of global loads @@ -413,7 +413,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS } // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { block_sync_lds(); auto bias_shuffle_tmp = make_static_distributed_tensor( @@ -461,7 +461,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS const auto lse = load_tile(lse_dram_window); static const auto get_validated_lse = [](LSEDataType raw_lse) { - if constexpr(kHasBias || FmhaMask::IsMasking) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS || + FmhaMask::IsMasking) { return raw_lse == -numeric::infinity() ? type_convert(0.f) @@ -483,7 +484,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) { constexpr auto i_j_idx = make_tuple(idx0, idx1); #if CK_TILE_FMHA_FWD_FAST_EXP2 - if constexpr(kHasBias) + if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse); } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index ba840e725b..a013ee3d57 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -644,7 +644,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias() { constexpr index_t smem_size_bias = [&]() { - if constexpr(Problem::kHasBias) + if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) return sizeof(typename Problem::BiasDataType) * MakeBiasTLdsBlockDescriptor().get_element_space_size(); else diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index 5ed41d6264..7b787e9f36 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -55,7 +55,7 @@ struct BlockFmhaBwdPipelineProblem static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr bool kHasBias = Traits::kHasBias; + static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; static constexpr bool kHasDropout = Traits::kHasDropout; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index c97073aaf5..a89536e6eb 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -3,7 +3,6 @@ #pragma once -#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" @@ -12,9 +11,13 @@ #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v2_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"