Initial reference implementation of hstu attention

This commit is contained in:
Qianfeng Zhang
2025-03-28 16:26:43 +00:00
parent 441343a23d
commit 4a0fc292d0
7 changed files with 615 additions and 0 deletions

View File

@@ -0,0 +1,21 @@
# Build configuration for the standalone HSTU-attention example target.
set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_HSTU_ATTENTION}")
##file(GLOB INSTANCE_SRCS instances/*.cpp)
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
##target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE hstu_attention_bf16.cpp hstu_attention_fp16.cpp ${INSTANCE_SRCS})
# extra warning suppressions for this example only
set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS)
list(APPEND EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS})
# TODO: we have to turn off this global prop, otherwise the progress bar generated
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
# however, this property may affect global
# TODO: consider codegen a makefile by us
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

View File

@@ -0,0 +1,53 @@
# HSTU attention operator
HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk]`,
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_qk]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`
* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask
as well as sequence mask
* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]`
* Multiply `p: [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen, nhead, hdim_v]`
* Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`
## implementation
The operator is implemented using a fused kernel in the example:
* Tensor S and Tensor P only exist in VGPRs as per-workgroup tiles, no global memory access is needed
## build
``` bash
#> mkdir build
#> cd build
#> ../script/cmake-ck-dev.sh .. gfx942   # use `rocminfo | grep "gfx"` to check your gpu arch
#> make -j tile_example_hstu_attention
```
## test/verify
``` bash
#> build/bin/tile_example_hstu_attention -v=1 -prec=fp16 -b=10 -nidx=9 -nhead=4 -hsizeq=64 -hsizev=64 -seqq=13 -seqk=512 -init=u -seed=123 -perf=0 -maskmax=0
#> . example/ck_tile/07_hstu_attention/test_hstu_attention.sh
```
Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments, which look like the following:
``` C++
arg_parser.insert("v", "1", "whether to do CPU validation or not")
.insert("prec", "fp16", "data type. fp16/bf16")
.insert("b", "12", "batch size")
.insert("nidx", "9", "number of indices for accessing the batches")
.insert("nhead", "4", "number of heads")
.insert("hsizeq", "64", "headdim size of Q/K")
.insert("hsizev", "64", "headdim size of V/O")
.insert("seqq", "13", "length of the sequence dimension of query tensor")
.insert("seqv", "1024", "length of the sequence dimension of key tensor")
.insert("init", "u", "init method for input tensor values, u, uniform random float values, n, normalized random float values")
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
        .insert("perf", "0", "whether to measure execution time or not")
.insert("maskmax", "0", "used to set mask values to random [0, maskmax), maskmax should in [0, 128], 0 means set all values to 1");
```

View File

@@ -0,0 +1,46 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// BOOL_SWITCH dispatches a runtime boolean to a compile-time constant:
// the trailing callable (usually a lambda) is invoked with CONST_NAME1
// declared as a constexpr bool, so it can be used as a template argument.
#define BOOL_SWITCH(COND1, CONST_NAME1, ...)        \
    [&] {                                           \
        if(COND1)                                   \
        {                                           \
            constexpr bool CONST_NAME1 = true;      \
            __VA_ARGS__();                          \
        }                                           \
        else                                        \
        {                                           \
            constexpr bool CONST_NAME1 = false;     \
            __VA_ARGS__();                          \
        }                                           \
    }()
// Two-condition variant: nests BOOL_SWITCH so the callable sees two
// constexpr bools. NOTE: ##__VA_ARGS__ uses the GNU comma-swallowing
// extension (supported by gcc/clang/hipcc) for an empty argument pack.
#define BOOL_SWITCH_2(COND1, CONST_NAME1, COND2, CONST_NAME2, ...)  \
    [&] {                                                           \
        if(COND1)                                                   \
        {                                                           \
            constexpr bool CONST_NAME1 = true;                      \
            BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__);         \
        }                                                           \
        else                                                        \
        {                                                           \
            constexpr bool CONST_NAME1 = false;                     \
            BOOL_SWITCH(COND2, CONST_NAME2, ##__VA_ARGS__);         \
        }                                                           \
    }()
// Three-condition variant: nests BOOL_SWITCH_2 for three constexpr bools.
#define BOOL_SWITCH_3(COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, ...)  \
    [&] {                                                                               \
        if(COND1)                                                                       \
        {                                                                               \
            constexpr bool CONST_NAME1 = true;                                          \
            BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__);       \
        }                                                                               \
        else                                                                            \
        {                                                                               \
            constexpr bool CONST_NAME1 = false;                                         \
            BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, ##__VA_ARGS__);       \
        }                                                                               \
    }()

View File

@@ -0,0 +1,259 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <array>
#include <cstring>
#include <functional>
#include <numeric>
#include <ostream>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include <random>
#include <cassert>
#include <ck_tile/host/host_tensor.hpp>
#include <ck_tile/host/fill.hpp>
#include <ck_tile/host/device_memory.hpp>
#include <ck_tile/host/stream_config.hpp>
#include <ck_tile/host/arg_parser.hpp>
#include <ck_tile/host/hip_check_error.hpp>
#include <ck_tile/host/check_err.hpp>
#include <ck_tile/host/timer.hpp>
#include "hstu_attention_setting.hpp"
#include "bool_switch.hpp"
#include "reference_hstu_attention.hpp"
// Stream a vector as "[e0, e1, ...]" — used to print parsed length lists.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";
    const char* sep = "";
    for(const T& elem : v)
    {
        os << sep << elem;
        sep = ", "; // separator only appears before the 2nd element onwards
    }
    return os << "]";
}
// Build the command-line parser for this example and parse argv.
// Returns (parse_succeeded, parser).
// Fixes typos in the user-facing help strings ("weather" -> "whether",
// "the should be" -> "that should be").
auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    // clang-format off
    arg_parser.insert("v", "1", "whether to do CPU validation or not")
        .insert("prec", "fp16", "data type. fp16/bf16")
        .insert("jagged", "0", "q/k/v batched sequence is jagged or not")
        .insert("b", "12", "batch size")
        .insert("nhead", "4", "number of heads")
        .insert("hdim_qk", "64", "headdim size of Q/K")
        .insert("hdim_v", "64", "headdim size of V/O")
        .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
        .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
        .insert("causal", "1", "enable causal mask or not")
        .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
        .insert("context_len", "6", "sequence length at the begin of the query sequence that should be included for attention")
        .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
        .insert("seed", "13579", "seed by the uniform or normal distribution generator")
        .insert("perf", "0", "whether to measure execution time or not");
    // clang-format on
    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}
// Parse a comma-separated list such as "13,400,512" into a vector of ints.
// Throws std::invalid_argument (from std::stoi) on empty/non-numeric slices,
// same as the original duplicated-tail implementation, but the tail segment
// is now handled by the same loop instead of a copy-pasted epilogue.
static std::vector<int> get_integers_from_string(std::string lengthsStr)
{
    std::vector<int> lengths;
    std::size_t pos = 0;
    // '<=' so the final (or only) segment after the last comma is parsed too
    while(pos <= lengthsStr.size())
    {
        std::size_t new_pos = lengthsStr.find(',', pos);
        if(new_pos == std::string::npos)
            new_pos = lengthsStr.size();
        lengths.push_back(std::stoi(lengthsStr.substr(pos, new_pos - pos)));
        pos = new_pos + 1;
    }
    return lengths;
}
// Relative/absolute error thresholds used when comparing results against
// the CPU reference, chosen per data type.
template <typename DataType>
auto get_elimit()
{
    double rtol = 2e-3;
    double atol = 2e-3;
    return ck_tile::make_tuple(rtol, atol);
}
// bf16 carries fewer mantissa bits than fp16, so use looser tolerances.
template <>
auto get_elimit<ck_tile::bf16_t>()
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}
// Drive one reference HSTU-attention run for the given element type.
// Reads the remaining command-line options, validates the sequence/target
// configuration, allocates and fills the host tensors, then invokes the
// CPU reference implementation.
// Returns true on success so main() can map it to exit code 0
// (the original returned 0, which converted to false and made main()
// report failure (-2) on every successful run).
template <typename InOutDataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    bool do_validation     = static_cast<bool>(arg_parser.get_int("v"));
    bool is_jagged         = static_cast<bool>(arg_parser.get_int("jagged"));
    int num_batch          = arg_parser.get_int("b");
    int nhead              = arg_parser.get_int("nhead");
    int hdim_qk            = arg_parser.get_int("hdim_qk");
    int hdim_v             = arg_parser.get_int("hdim_v");
    bool use_causal        = static_cast<bool>(arg_parser.get_int("causal"));
    int max_attn_len       = arg_parser.get_int("local_len");
    // a positive window length enables the diagonal local mask
    bool use_local         = (max_attn_len > 0);
    int contextual_seq_len = arg_parser.get_int("context_len");
    int min_full_seq_len   = arg_parser.get_int("minfull_len");
    int seed               = arg_parser.get_int("seed");
    bool measure_perf      = static_cast<bool>(arg_parser.get_int("perf"));
    // device validation and timing are not wired up yet in this
    // reference-only example
    (void)do_validation;
    (void)measure_perf;
    std::string str_of_targets   = arg_parser.get_str("targets");
    std::vector<int> num_targets = get_integers_from_string(str_of_targets);
    std::string str_of_lengths   = arg_parser.get_str("seqlen");
    std::vector<int> seq_lengths = get_integers_from_string(str_of_lengths);
    std::vector<int> seq_offsets;
    int seqlen = 0; // total of all batch seq lengths when jagged
    if(is_jagged)
    {
        // jagged mode: one length per batch, prefix-summed into offsets
        assert(num_batch == static_cast<int>(seq_lengths.size()));
        seq_offsets.push_back(0);
        for(size_t i = 0; i < seq_lengths.size(); i++)
        {
            seqlen += seq_lengths[i];
            seq_offsets.push_back(seqlen);
        }
        if(!num_targets.empty())
        {
            assert(num_batch == static_cast<int>(num_targets.size()));
            for(size_t i = 0; i < seq_lengths.size(); i++)
            {
                // the non-target region must cover both the trailing
                // full-attention window and the contextual prefix
                assert(seq_lengths[i] - num_targets[i] >= min_full_seq_len);
                assert(seq_lengths[i] - num_targets[i] >= contextual_seq_len);
            }
        }
        else
        {
            for(size_t i = 0; i < seq_lengths.size(); i++)
            {
                assert(seq_lengths[i] >= min_full_seq_len);
                assert(seq_lengths[i] >= contextual_seq_len);
            }
        }
    }
    else
    {
        // uniform mode: a single seqlen shared by all batches
        assert(1 == seq_lengths.size());
        seqlen = seq_lengths[0];
        if(!num_targets.empty())
        {
            assert(1 == num_targets.size());
            assert(seqlen - num_targets[0] >= min_full_seq_len);
            assert(seqlen - num_targets[0] >= contextual_seq_len);
        }
        else
        {
            assert(seqlen >= min_full_seq_len);
            assert(seqlen >= contextual_seq_len);
        }
    }
    // jagged tensors concatenate every batch's sequence in one batch slot
    int batches_for_alloc = is_jagged ? 1 : num_batch;
    ck_tile::HostTensor<InOutDataType> q_host(
        std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_qk});
    ck_tile::HostTensor<InOutDataType> k_host(
        std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_qk});
    ck_tile::HostTensor<InOutDataType> v_host(
        std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_v});
    ck_tile::HostTensor<InOutDataType> o_host_ref(
        std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_v});
    ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(q_host);
    ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(k_host);
    ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(v_host);
    using GemmAccDataType   = typename HSTUAttentionTypeConfig<InOutDataType>::GemmAccDataType;
    using SMComputeDataType = typename HSTUAttentionTypeConfig<InOutDataType>::SMComputeDataType;
    // resolve the runtime mask flags into compile-time template arguments
    BOOL_SWITCH_2(use_causal, USE_CAUSAL_, use_local, USE_LOCAL_, [&] {
        ck_tile::reference_hstu_attention<InOutDataType,
                                          GemmAccDataType,
                                          SMComputeDataType,
                                          USE_CAUSAL_,
                                          USE_LOCAL_>::Run(q_host,
                                                           k_host,
                                                           v_host,
                                                           o_host_ref,
                                                           num_batch,
                                                           1.0f,
                                                           seq_offsets,
                                                           num_targets,
                                                           max_attn_len,
                                                           contextual_seq_len,
                                                           min_full_seq_len);
    });
    return true;
}
// Entry point: parse the command line, then dispatch on the requested
// precision. Exit codes: 0 success, -1 bad arguments, -2 run failure,
// -3 unsupported data type.
int main(int argc, char* argv[])
{
    auto [parsed_ok, arg_parser] = create_args(argc, argv);
    if(!parsed_ok)
    {
        std::cerr << "Invalid arguments, Failed to parse!" << std::endl;
        return -1;
    }
    const std::string data_type = arg_parser.get_str("prec");
    const bool is_fp16          = (data_type == "fp16");
    const bool is_bf16          = (data_type == "bf16");
    if(!is_fp16 && !is_bf16)
        return -3; // unsupported precision string
    const bool ok =
        is_fp16 ? run<ck_tile::half_t>(arg_parser) : run<ck_tile::bf16_t>(arg_parser);
    return ok ? 0 : -2;
}

View File

@@ -0,0 +1,24 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ck_tile/core.hpp>
// Type configuration: maps the in/out element type of the example to the
// accumulation/compute data types used by the attention implementation.
template <typename DataType>
struct HSTUAttentionTypeConfig;
template <>
struct HSTUAttentionTypeConfig<ck_tile::fp16_t>
{
    using GemmAccDataType   = float; // accumulator type for both GEMMs
    using SMComputeDataType = float; // type used for the masking/SiLU stage
};
template <>
struct HSTUAttentionTypeConfig<ck_tile::bf16_t>
{
    using GemmAccDataType   = float; // accumulator type for both GEMMs
    using SMComputeDataType = float; // type used for the masking/SiLU stage
};

View File

@@ -0,0 +1,211 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <thread>
#include <mutex>
#include <cassert>
#include <cmath>
#include <ck_tile/core.hpp>
#include <ck_tile/host/host_tensor.hpp>
#include "bool_switch.hpp"
namespace ck_tile {
// clang-format off
// Reference implementation of HSTUAttention problem, which does the following from input tensors:
// S[num_batch, num_head, seqlen, seqlen] = Q[num_batch, seqlen, num_head, hdim_qk] @ key^T[num_batch, seqlen, num_head, hdim_qk]
// P[num_batch, num_head, seqlen, seqlen] = Masking(SiLu(S[num_batch, num_head, seqlen, seqlen]))
// O[num_batch, num_head, seqlen, hdim_v] = P[num_batch, num_head, seqlen, seqlen] @ value^T[num_batch, num_head, seqlen, hdim_v]
// The process is very similar to the generic attention, the difference is that SiLu is used rather than Softmax, and hstu masking
// is much more complicated than the lower-triangular + diagonal-window based causal mask
// clang-format on
template <typename InOutDataType,
          typename GemmAccDataType,
          typename CompDataType,
          bool use_causal,
          bool use_local>
struct reference_hstu_attention
{
    // Functional mask deciding which (row, col) entries of the score matrix
    // participate in attention. Combines lower-triangular/diagonal-window
    // causal logic with a contextual prefix region and a trailing
    // minimum-full-attention region.
    struct hstu_mask
    {
        int max_attn_len;          // diagonal window size (only used with use_local)
        int contextual_seq_len;    // rows in [0, contextual_seq_len) attend everywhere
        int min_full_attn_seq_len; // trailing rows that keep full attention
        int max_uih_len;           // effective sequence length for this batch

        hstu_mask(int max_attn_len_,
                  int contextual_seq_len_,
                  int min_full_attn_seq_len_,
                  int max_uih_len_)
            : max_attn_len(max_attn_len_),
              contextual_seq_len(contextual_seq_len_),
              min_full_attn_seq_len(min_full_attn_seq_len_),
              max_uih_len(max_uih_len_)
        {
        }

        // true when score position (row, col) is kept by the mask
        bool IsPixelInsideMask(int row, int col) const
        {
            // contextual rows attend to every column
            if(row < contextual_seq_len)
                return true;
            bool result = false;
            if constexpr(use_local)
            {
                if constexpr(use_causal)
                    result = (row >= col) && (row - col <= max_attn_len);
                else
                    result = std::abs(row - col) <= max_attn_len;
                // rows near the end of the sequence get full attention
                if(min_full_attn_seq_len > 0)
                    result = result || (row >= max_uih_len - min_full_attn_seq_len);
            }
            else
            {
                if constexpr(use_causal)
                    result = (row >= col);
            }
            return result;
        }
    };

    // CPU reference:
    //   S = alpha * Q @ K^T, filtered by hstu_mask
    //   P = SiLU(S) / seqlen
    //   O = P @ V
    // One (batch, head) pair is processed per worker thread.
    // FIX: in jagged mode the tensors are allocated with batch dim 1 and all
    // batch sequences concatenated along the sequence dim (asserted below),
    // but the original indexed them with (i_batch, sq, ...), reading out of
    // bounds for i_batch > 0 and ignoring seq_offsets; elements are now
    // addressed as (0, seq_offsets[i_batch] + sq, ...) when jagged.
    static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
                    const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
                    const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
                    HostTensor<InOutDataType>& o_batch_seq_nhead_hdim,
                    int num_batch,
                    float alpha,
                    std::vector<int> seq_offsets,
                    std::vector<int> num_targets, // define masking length at the end of token
                                                  // sequence to be excluded for attention
                    int max_attn_len,             // define the diagonal local window size
                    int contextual_seq_len,       // define masking length at the begin of query token
                                                  // sequence to be included for attention
                    int min_full_attn_seq_len)    // define masking length at the end of query token
                                                  // sequence which is included for full attention
    {
        bool is_jagged = !seq_offsets.empty();
        if(is_jagged)
        {
            // jagged layout: batch dim 1, sequences concatenated and
            // delimited by seq_offsets (num_batch + 1 entries)
            assert(static_cast<int>(seq_offsets.size()) == num_batch + 1);
            assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1);
            assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1);
            assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1);
            assert(o_batch_seq_nhead_hdim.get_lengths()[0] == 1);
        }
        else
        {
            assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
            assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
            assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
            assert(o_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
        }
        // q/k/v/o must share the same sequence length
        assert(q_batch_seq_nhead_hdim.get_lengths()[1] == k_batch_seq_nhead_hdim.get_lengths()[1]);
        assert(q_batch_seq_nhead_hdim.get_lengths()[1] == v_batch_seq_nhead_hdim.get_lengths()[1]);
        assert(q_batch_seq_nhead_hdim.get_lengths()[1] == o_batch_seq_nhead_hdim.get_lengths()[1]);
        // q/k/v/o must share the same number of heads
        int num_head = q_batch_seq_nhead_hdim.get_lengths()[2];
        assert(num_head == k_batch_seq_nhead_hdim.get_lengths()[2]);
        assert(num_head == v_batch_seq_nhead_hdim.get_lengths()[2]);
        assert(num_head == o_batch_seq_nhead_hdim.get_lengths()[2]);
        // head dims: q/k share hdim_qk, v/o share hdim_v
        int hdim_qk = q_batch_seq_nhead_hdim.get_lengths()[3];
        int hdim_v  = v_batch_seq_nhead_hdim.get_lengths()[3];
        assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]);
        assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]);

        // SiLU(x) = x * sigmoid(x)
        auto silu = [](CompDataType x) {
            auto one         = ck_tile::type_convert<CompDataType>(1.0f);
            auto sigmoid_val = one / (one + std::exp(-x));
            return sigmoid_val * x;
        };

        bool has_target = !num_targets.empty();
        if(has_target)
            assert(static_cast<int>(num_targets.size()) == num_batch);

        auto f = [&](auto i_batch, auto i_head) {
            int seqlen      = is_jagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
                                        : q_batch_seq_nhead_hdim.get_lengths()[1];
            int max_uih_len = seqlen;
            // NOTE(review): the '- 1' keeps one contextual row inside the uih
            // region — confirm against the device kernel's definition
            if(contextual_seq_len > 0)
                max_uih_len -= contextual_seq_len - 1;
            if(has_target)
                max_uih_len -= num_targets[i_batch];

            // jagged tensors keep every batch at batch index 0, shifted by
            // the batch's sequence offset
            const int b        = is_jagged ? 0 : static_cast<int>(i_batch);
            const int seq_base = is_jagged ? seq_offsets[i_batch] : 0;

            hstu_mask mask{max_attn_len, contextual_seq_len, min_full_attn_seq_len, max_uih_len};
            // for all rows in the batch
            for(int sq = 0; sq < max_uih_len; sq++)
            {
                std::vector<CompDataType> locals;
                // first gemm: one masked row of S = alpha * Q @ K^T
                for(int sk = 0; sk < max_uih_len; sk++)
                {
                    if(mask.IsPixelInsideMask(sq, sk))
                    {
                        GemmAccDataType dot_prod = 0.f;
                        for(int k = 0; k < hdim_qk; k++)
                        {
                            InOutDataType qreg =
                                q_batch_seq_nhead_hdim(b, seq_base + sq, i_head, k);
                            InOutDataType kreg =
                                k_batch_seq_nhead_hdim(b, seq_base + sk, i_head, k);
                            dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
                                        ck_tile::type_convert<GemmAccDataType>(kreg);
                        }
                        locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
                                         ck_tile::type_convert<CompDataType>(alpha));
                    }
                    else
                        locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
                }
                // SiLU element-wise, scaled by 1/seqlen
                for(CompDataType& elem : locals)
                    elem = silu(elem) / ck_tile::type_convert<CompDataType>(seqlen);
                // second gemm: O row = P row @ V
                for(int k = 0; k < hdim_v; k++)
                {
                    GemmAccDataType dot_prod = 0.f;
                    for(int sk = 0; sk < max_uih_len; sk++)
                    {
                        // round P through the storage type to mimic the kernel
                        InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
                        InOutDataType vreg = v_batch_seq_nhead_hdim(b, seq_base + sk, i_head, k);
                        dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
                                    ck_tile::type_convert<GemmAccDataType>(vreg);
                    }
                    o_batch_seq_nhead_hdim(b, seq_base + sq, i_head, k) =
                        ck_tile::type_convert<InOutDataType>(dot_prod);
                }
            }
        };
        make_ParallelTensorFunctor(f, num_batch, num_head)(std::thread::hardware_concurrency());
    }
};
} // namespace ck_tile

View File

@@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
add_subdirectory(18_hstu_attention)
add_subdirectory(35_batched_transpose)