diff --git a/example/ck_tile/18_hstu_attention/CMakeLists.txt b/example/ck_tile/18_hstu_attention/CMakeLists.txt new file mode 100644 index 0000000000..b62b32e14e --- /dev/null +++ b/example/ck_tile/18_hstu_attention/CMakeLists.txt @@ -0,0 +1,21 @@ +set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_HSTU_ATTENTION}") +##file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp) +target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +##target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE hstu_attention_bf16.cpp hstu_attention_fp16.cpp ${INSTANCE_SRCS}) + +set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS) + +list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global builds +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) + diff --git a/example/ck_tile/18_hstu_attention/README.md b/example/ck_tile/18_hstu_attention/README.md new file mode 100644 index 0000000000..3ce1f27a14 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/README.md @@ -0,0 +1,53 @@ +# HSTU attention operator + + The HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk]`, + `v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and does the following: + + * Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, 
seqlen, nhead, hdim_qk]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]` + * Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask + as well as sequence mask + * Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]` + * Multiply `p : [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen, nhead, hdim_v]` + * Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]` + + +## implementation + + The operator is implemented using a fused kernel in the example: + + * Tensor S and Tensor P only exist in VGPRs as per-workgroup tiles, no global memory access is needed + +## build + + ``` bash + #> mkdir build + #> cd build + #> ../script/cmake-ck-dev.sh .. gfx942 ; use #> rocminfo |grep "gfx" to check your gpu arch + #> make -j tile_example_hstu_attention + ``` + +## test/verify + + ``` bash + #> build/bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -nidx=9 -nhead=4 -hsizeq=64 -hsizev=64 -seqq=13 -seqk=512 -init=u -seed=123 -perf=0 -maskmax=0 + #> . example/ck_tile/18_hstu_attention/test_hstu_attention.sh + ``` + + Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments, which are like the following: + + ``` C++ + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("prec", "fp16", "data type. 
fp16/bf16") + .insert("b", "12", "batch size") + .insert("nidx", "9", "number of indices for accessing the batches") + .insert("nhead", "4", "number of heads") + .insert("hsizeq", "64", "headdim size of Q/K") + .insert("hsizev", "64", "headdim size of V/O") + .insert("seqq", "13", "length of the sequence dimension of query tensor") + .insert("seqv", "1024", "length of the sequence dimension of key tensor") + .insert("init", "u", "init method for input tensor values, u, uniform random float values, n, normalized random float values") + .insert("seed", "13579", "seed by the uniform or normal distribution generator") + .insert("perf", "0", "weather measure execution time or not") + .insert("maskmax", "0", "used to set mask values to random [0, maskmax), maskmax should in [0, 128], 0 means set all values to 1"); + ``` + diff --git a/example/ck_tile/18_hstu_attention/bool_switch.hpp b/example/ck_tile/18_hstu_attention/bool_switch.hpp new file mode 100644 index 0000000000..22e25d97e7 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/bool_switch.hpp @@ -0,0 +1,46 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#define BOOL_SWITCH(COND1, CONST_NAME1, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + __VA_ARGS__(); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + __VA_ARGS__(); \ + } \ + }() + +#define BOOL_SWITCH_2(COND1, CONST_NAME1, COND2, CONST_NAME2, ...) \ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__); \ + } \ + }() + +#define BOOL_SWITCH_3(COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, ...) 
\ + [&] { \ + if(COND1) \ + { \ + constexpr bool CONST_NAME1 = true; \ + BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__); \ + } \ + else \ + { \ + constexpr bool CONST_NAME1 = false; \ + BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__); \ + } \ + }() diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp new file mode 100644 index 0000000000..9432819199 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -0,0 +1,259 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "hstu_attention_setting.hpp" +#include "bool_switch.hpp" +#include "reference_hstu_attention.hpp" + +template +std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + using size_type = typename std::vector::size_type; + + os << "["; + for(size_type idx = 0; idx < v.size(); ++idx) + { + if(0 < idx) + { + os << ", "; + } + os << v[idx]; + } + return os << "]"; +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + + // clang-format off + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("prec", "fp16", "data type. 
fp16/bf16") + .insert("jagged", "0", "q/k/v batched sequence is jagged or not") + .insert("b", "12", "batch size") + .insert("nhead", "4", "number of heads") + .insert("hdim_qk", "64", "headdim size of Q/K") + .insert("hdim_v", "64", "headdim size of V/O") + .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor") + .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention") + .insert("causal", "1", "enable causal mask or not") + .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable") + .insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention") + .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention") + .insert("seed", "13579", "seed by the uniform or normal distribution generator") + .insert("perf", "0", "weather measure execution time or not"); + // clang-format on + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +static std::vector get_integers_from_string(std::string lengthsStr) +{ + std::vector lengths; + std::size_t pos = 0; + std::size_t new_pos; + + new_pos = lengthsStr.find(',', pos); + while(new_pos != std::string::npos) + { + std::string sliceStr = lengthsStr.substr(pos, new_pos - pos); + + int len = std::stoi(sliceStr); + + lengths.push_back(len); + + pos = new_pos + 1; + new_pos = lengthsStr.find(',', pos); + }; + + std::string sliceStr = lengthsStr.substr(pos); + int len = std::stoi(sliceStr); + + lengths.push_back(len); + + return (lengths); +}; + +// threshold for different dtypes +template +auto get_elimit() +{ + double rtol = 2e-3; + double atol = 2e-3; + + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template +bool 
run(const ck_tile::ArgParser& arg_parser) +{ + bool do_validation = static_cast(arg_parser.get_int("v")); + bool is_jagged = static_cast(arg_parser.get_int("jagged")); + int num_batch = arg_parser.get_int("b"); + int nhead = arg_parser.get_int("nhead"); + int hdim_qk = arg_parser.get_int("hdim_qk"); + int hdim_v = arg_parser.get_int("hdim_v"); + bool use_causal = static_cast(arg_parser.get_int("causal")); + + int max_attn_len = arg_parser.get_int("local_len"); + + bool use_local = (max_attn_len > 0); + + int contextual_seq_len = arg_parser.get_int("context_len"); + int min_full_seq_len = arg_parser.get_int("minfull_len"); + + int seed = arg_parser.get_int("seed"); + + bool measure_perf = static_cast(arg_parser.get_int("perf")); + + (void)do_validation; + (void)measure_perf; + + std::string str_of_targets = arg_parser.get_str("targets"); + std::vector num_targets = get_integers_from_string(str_of_targets); + + std::string str_of_lengths = arg_parser.get_str("seqlen"); + std::vector seq_lengths = get_integers_from_string(str_of_lengths); + + std::vector seq_offsets; + + int seqlen = 0; // means total seq lengths for jagged + + if(is_jagged) + { + assert(num_batch == seq_lengths.size()); + + seq_offsets.push_back(0); + for(size_t i = 0; i < seq_lengths.size(); i++) + { + seqlen += seq_lengths[i]; + seq_offsets.push_back(seqlen); + }; + + if(!num_targets.empty()) + { + assert(num_batch == num_targets.size()); + + for(size_t i = 0; i < seq_lengths.size(); i++) + { + assert(seq_lengths[i] - num_targets[i] >= min_full_seq_len); + assert(seq_lengths[i] - num_targets[i] >= contextual_seq_len); + }; + } + else + { + for(size_t i = 0; i < seq_lengths.size(); i++) + { + assert(seq_lengths[i] >= min_full_seq_len); + assert(seq_lengths[i] >= contextual_seq_len); + }; + }; + } + else + { + assert(1 == seq_lengths.size()); + seqlen = seq_lengths[0]; + + if(!num_targets.empty()) + { + assert(1 == num_targets.size()); + + assert(seqlen - num_targets[0] >= min_full_seq_len); + 
assert(seqlen - num_targets[0] >= contextual_seq_len); + } + else + { + assert(seqlen >= min_full_seq_len); + assert(seqlen >= contextual_seq_len); + }; + }; + + int batches_for_alloc = is_jagged ? 1 : num_batch; + + ck_tile::HostTensor q_host( + std::array{batches_for_alloc, seqlen, nhead, hdim_qk}); + ck_tile::HostTensor k_host( + std::array{batches_for_alloc, seqlen, nhead, hdim_qk}); + ck_tile::HostTensor v_host( + std::array{batches_for_alloc, seqlen, nhead, hdim_v}); + ck_tile::HostTensor o_host_ref( + std::array{batches_for_alloc, seqlen, nhead, hdim_v}); + + ck_tile::FillNormalDistributionIntegerValue{-2.f, 2.f, seed}(q_host); + ck_tile::FillNormalDistributionIntegerValue{-2.f, 2.f, seed}(k_host); + ck_tile::FillNormalDistributionIntegerValue{-2.f, 2.f, seed}(v_host); + + using GemmAccDataType = typename HSTUAttentionTypeConfig::GemmAccDataType; + using SMComputeDataType = typename HSTUAttentionTypeConfig::SMComputeDataType; + + BOOL_SWITCH_2(use_causal, USE_CAUSAL_, use_local, USE_LOCAL_, [&] { + ck_tile::reference_hstu_attention::Run(q_host, + k_host, + v_host, + o_host_ref, + num_batch, + 1.0f, + seq_offsets, + num_targets, + max_attn_len, + contextual_seq_len, + min_full_seq_len); + }); + return 0; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + { + std::cerr << "Invalid arguments, Failed to parse!" << std::endl; + return -1; + } + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 
0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_setting.hpp new file mode 100644 index 0000000000..3225d05651 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_setting.hpp @@ -0,0 +1,24 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +// Type configuration +template +struct HSTUAttentionTypeConfig; + +template <> +struct HSTUAttentionTypeConfig +{ + using GemmAccDataType = float; + using SMComputeDataType = float; +}; + +template <> +struct HSTUAttentionTypeConfig +{ + using GemmAccDataType = float; + using SMComputeDataType = float; +}; diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp new file mode 100644 index 0000000000..d512aedcd3 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -0,0 +1,211 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include + +#include +#include + +#include "bool_switch.hpp" + +namespace ck_tile { + +// clang-format off +// Reference implementation of HSTUAttention problem, which does the following from input tensors: +// S[num_batch, num_head, seqlen, seqlen] = Q[num_batch, seqlen, num_head, hdim_qk] @ key^T[num_batch, seqlen, num_head, hdim_v] +// P[num_batch, num_head, seqlen, seqlen] = Masking(SiLu(S[num_batch, num_head, seqlen, seqlen])) +// O[num_batch, num_head, seqlen, hdim_v] = P[num_batch, num_head, seqlen, seqlen] @ value^T[num_batch, num_head, seqlen, hdim_v] +// The process is very similar to the generic attention, the difference is that SiLu is used rather than Softmax, and hstu masking +// is much more complicated than the lower-triangular + disagonal-window based causal mask +// clang-format on + +template +struct reference_hstu_attention +{ + struct hstu_mask + { + int max_attn_len; + int contextual_seq_len; + int min_full_attn_seq_len; + int max_uih_len; + + hstu_mask(int max_attn_len_, + int contextual_seq_len_, + int min_full_attn_seq_len_, + int max_uih_len_) + { + max_attn_len = max_attn_len_; + contextual_seq_len = contextual_seq_len_; + min_full_attn_seq_len = min_full_attn_seq_len_; + max_uih_len = max_uih_len_; + }; + + bool IsPixelInsideMask(int row, int col) + { + if(row < contextual_seq_len) + return true; + + bool result = false; + if constexpr(use_local) + { + if constexpr(use_causal) + result = (row >= col) && (row - col <= max_attn_len); + else + result = std::abs(row - col) <= max_attn_len; + + if(min_full_attn_seq_len > 0) + result = result || (row >= max_uih_len - min_full_attn_seq_len); + } + else + { + if constexpr(use_causal) + result = (row >= col); + }; + + return result; + }; + }; + + static void Run(const HostTensor& q_batch_seq_nhead_hdim, + const HostTensor& k_batch_seq_nhead_hdim, + const HostTensor& v_batch_seq_nhead_hdim, + HostTensor& o_batch_seq_nhead_hdim, + int num_batch, + 
float alpha, + std::vector seq_offsets, + std::vector num_targets, // define masking length at the end of token + // sequence to be excluded for attention + int max_attn_len, // define the diagonal local window size + int contextual_seq_len, // define masking length at the begin of query token + // sequence to be included for attention + int min_full_attn_seq_len) // define masking length at the end of query token + // sequence which is included for full attention + { + bool is_jagged = !seq_offsets.empty(); + + if(is_jagged) + { + // check the number of batches + assert(seq_offsets.size() == num_batch + 1); + assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1); + assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1); + assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1); + assert(o_batch_seq_nhead_hdim.get_lengths()[0] == 1); + } + else + { + assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); + assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); + assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); + assert(o_batch_seq_nhead_hdim.get_lengths()[0] == num_batch); + }; + + // check the sequence length + assert(q_batch_seq_nhead_hdim.get_lengths()[1] == k_batch_seq_nhead_hdim.get_lengths()[1]); + assert(q_batch_seq_nhead_hdim.get_lengths()[1] == v_batch_seq_nhead_hdim.get_lengths()[1]); + assert(q_batch_seq_nhead_hdim.get_lengths()[1] == o_batch_seq_nhead_hdim.get_lengths()[1]); + + // check the number of heads + int num_head = q_batch_seq_nhead_hdim.get_lengths()[2]; + assert(num_head == k_batch_seq_nhead_hdim.get_lengths()[2]); + assert(num_head == v_batch_seq_nhead_hdim.get_lengths()[2]); + assert(num_head == o_batch_seq_nhead_hdim.get_lengths()[2]); + + // check the hdim + int hdim_qk = q_batch_seq_nhead_hdim.get_lengths()[3]; + int hdim_v = v_batch_seq_nhead_hdim.get_lengths()[3]; + assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]); + assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]); + + auto silu = 
[](CompDataType x) { + auto one = ck_tile::type_convert(1.0f); + + auto sigmod_val = one / (one + std::exp(-x)); + + return sigmod_val * x; + }; + + bool has_target = !num_targets.empty(); + + if(has_target) + assert(num_targets.size() == num_batch); + + auto f = [&](auto i_batch, auto i_head) { + int seqlen = is_jagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch]) + : q_batch_seq_nhead_hdim.get_lengths()[1]; + + int max_uih_len = seqlen; + + if(contextual_seq_len > 0) + max_uih_len -= contextual_seq_len - 1; + + if(has_target) + max_uih_len -= num_targets[i_batch]; + + hstu_mask mask{max_attn_len, contextual_seq_len, min_full_attn_seq_len, max_uih_len}; + + // for all rows in the batch + for(int sq = 0; sq < max_uih_len; sq++) + { + std::vector locals; + + // for all cols in the batch + for(int sk = 0; sk < max_uih_len; sk++) + { + if(mask.IsPixelInsideMask(sq, sk)) + { + GemmAccDataType dot_prod = 0.f; + for(int k = 0; k < hdim_qk; k++) + { + InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k); + InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k); + + dot_prod += ck_tile::type_convert(qreg) * + ck_tile::type_convert(kreg); + } + + locals.push_back(ck_tile::type_convert(dot_prod) * + ck_tile::type_convert(alpha)); + } + else + locals.push_back(ck_tile::type_convert(0.0f)); + }; + + // SiLu element-wise + for(CompDataType& elem : locals) + elem = silu(elem) / ck_tile::type_convert(seqlen); + + // second Gemm + for(int k = 0; k < hdim_v; k++) + { + GemmAccDataType dot_prod = 0.f; + + for(int sk = 0; sk < max_uih_len; sk++) + { + InOutDataType preg = ck_tile::type_convert(locals[sk]); + InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k); + + dot_prod += ck_tile::type_convert(preg) * + ck_tile::type_convert(vreg); + }; + + o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) = + ck_tile::type_convert(dot_prod); + }; + }; + }; + + make_ParallelTensorFunctor(f, num_batch, num_head)(std::thread::hardware_concurrency()); + 
} +}; + +} // namespace ck_tile diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 7f4ba2ed35..85e751e5dd 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant) add_subdirectory(15_fused_moe) add_subdirectory(16_batched_gemm) add_subdirectory(17_grouped_gemm) +add_subdirectory(18_hstu_attention) add_subdirectory(35_batched_transpose)