mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Initial reference implementation of hstu attention
This commit is contained in:
21
example/ck_tile/18_hstu_attention/CMakeLists.txt
Normal file
21
example/ck_tile/18_hstu_attention/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
|
||||
# HSTU attention reference example, built from a single translation unit.
set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")

# add_example_executable() is deliberately not used: EXCLUDE_FROM_ALL keeps this
# target out of "make all/install/check". Build it explicitly with
#   make tile_example_hstu_attention
message(STATUS "adding example ${EXAMPLE_HSTU_ATTENTION}")

add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})

# Warnings silenced for this target only.
target_compile_options(${EXAMPLE_HSTU_ATTENTION}
    PRIVATE
        -Wno-undefined-func-template
        -Wno-float-equal)

# NOTE: the GLOBAL property RULE_MESSAGES is intentionally left untouched here.
# Turning it off is a workaround for targets with so many sources that the
# progress rule overflows the shell argument list ("Argument list too long");
# this target has one source file, and the property would affect every other
# target in the build tree.
|
||||
|
||||
53
example/ck_tile/18_hstu_attention/README.md
Normal file
53
example/ck_tile/18_hstu_attention/README.md
Normal file
@@ -0,0 +1,53 @@
|
||||
# HSTU attention operator
|
||||
|
||||
HSTU-attention operator is an operator which takes tensor `q: [batches, seqlen, nhead, hdim_qk]`, `k: [batches, seqlen, nhead, hdim_qk]`,
|
||||
`v: [batches, seqlen, nhead, hdim_v]` and some parameters for defining the functional masking as inputs, and do the following:
|
||||
|
||||
* Multiply `q: [batches, seqlen, nhead, hdim_qk]` with `k: [batches, seqlen, nhead, hdim_qk]` to get temporary tensor `s: [batches, nhead, seqlen, seqlen]`
|
||||
* Update `s` by filtering its values according to a special functional mask, which includes the logics of lower-triangular and diagonal window causal mask
|
||||
as well as sequence mask
|
||||
* Do element-wise SiLu on the `lower seqlen` dimension of `s` to get temporary tensor `p: [batches, nhead, seqlen, seqlen]`
|
||||
* Multiply `p: [batches, nhead, seqlen, seqlen]` with `v: [batches, seqlen, nhead, hdim_v]` to get final output `o: [batches, seqlen, nhead, hdim_v]`
|
||||
* Jagged inputs are also supported, where each batch has separate seqlen defined by the `sequence_offsets[]`
|
||||
|
||||
|
||||
## implementation
|
||||
|
||||
The operator is implemented using a fused kernel in the example:
|
||||
|
||||
* Tensor S and Tensor P only exist in VGPRs as per-workgroup tiles, no global memory access is needed
|
||||
|
||||
## build
|
||||
|
||||
``` bash
|
||||
#> mkdir build
|
||||
#> cd build
|
||||
#> ../script/cmake-ck-dev.sh .. gfx942 ; use #> rocminfo |grep "gfx" to check your gpu arch
|
||||
#> make -j tile_example_hstu_attention
|
||||
```
|
||||
|
||||
## test/verify
|
||||
|
||||
``` bash
|
||||
#> build/bin/tile_example_hstu_attention -v=1 -prec=fp16 -jagged=0 -b=10 -nhead=4 -hdim_qk=64 -hdim_v=64 -seqlen=512 -targets=16 -causal=1 -local_len=5 -context_len=6 -minfull_len=6 -seed=123 -perf=0
|
||||
#> . example/ck_tile/18_hstu_attention/test_hstu_attention.sh
|
||||
```
|
||||
|
||||
Check the example file `example_hstu_attention.cpp` for an understanding of the command-line arguments, which are defined as follows:
|
||||
|
||||
``` C++
|
||||
arg_parser.insert("v", "1", "whether do CPU validation or not")
    .insert("prec", "fp16", "data type. fp16/bf16")
    .insert("jagged", "0", "q/k/v batched sequence is jagged or not")
    .insert("b", "12", "batch size")
    .insert("nhead", "4", "number of heads")
    .insert("hdim_qk", "64", "headdim size of Q/K")
    .insert("hdim_v", "64", "headdim size of V/O")
    .insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
    .insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
    .insert("causal", "1", "enable causal mask or not")
    .insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
    .insert("context_len", "6", "sequence length at the begin of the query sequence that should be included for attention")
    .insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
    .insert("seed", "13579", "seed by the uniform or normal distribution generator")
    .insert("perf", "0", "whether measure execution time or not");
|
||||
```
|
||||
|
||||
46
example/ck_tile/18_hstu_attention/bool_switch.hpp
Normal file
46
example/ck_tile/18_hstu_attention/bool_switch.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// Runtime-bool to compile-time-bool dispatch helpers.
//
// BOOL_SWITCH(cond, NAME, callable)
//   Evaluates `cond` once at run time and invokes `callable()` inside a scope
//   where `NAME` is a `constexpr bool` equal to the outcome, so `NAME` can be
//   used as a template argument or in `if constexpr` inside the callable.
//   The trailing argument is variadic only so that a callable containing
//   top-level commas (e.g. template argument lists) survives macro argument
//   splitting; it must always be supplied.
//   The whole expression is an immediately-invoked `[&]` lambda whose result is
//   discarded, i.e. any value returned by the callable is not propagated.
//
// BOOL_SWITCH_2 / BOOL_SWITCH_3
//   Same idea for two / three independent conditions; the callable is
//   instantiated once per combination (4 / 8 instantiations).
//
// Plain `__VA_ARGS__` is forwarded to the nested macro instead of the GNU
// `##__VA_ARGS__` comma-paste extension: the callable is mandatory, so the
// extension only produced pedantic warnings without adding functionality.
#define BOOL_SWITCH(COND1, CONST_NAME1, ...)    \
    [&] {                                       \
        if(COND1)                               \
        {                                       \
            constexpr bool CONST_NAME1 = true;  \
            __VA_ARGS__();                      \
        }                                       \
        else                                    \
        {                                       \
            constexpr bool CONST_NAME1 = false; \
            __VA_ARGS__();                      \
        }                                       \
    }()

#define BOOL_SWITCH_2(COND1, CONST_NAME1, COND2, CONST_NAME2, ...) \
    [&] {                                                          \
        if(COND1)                                                  \
        {                                                          \
            constexpr bool CONST_NAME1 = true;                     \
            BOOL_SWITCH(COND2, CONST_NAME2, __VA_ARGS__);          \
        }                                                          \
        else                                                       \
        {                                                          \
            constexpr bool CONST_NAME1 = false;                    \
            BOOL_SWITCH(COND2, CONST_NAME2, __VA_ARGS__);          \
        }                                                          \
    }()

#define BOOL_SWITCH_3(COND1, CONST_NAME1, COND2, CONST_NAME2, COND3, CONST_NAME3, ...) \
    [&] {                                                                              \
        if(COND1)                                                                      \
        {                                                                              \
            constexpr bool CONST_NAME1 = true;                                         \
            BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, __VA_ARGS__);        \
        }                                                                              \
        else                                                                           \
        {                                                                              \
            constexpr bool CONST_NAME1 = false;                                        \
            BOOL_SWITCH_2(COND2, CONST_NAME2, COND3, CONST_NAME3, __VA_ARGS__);        \
        }                                                                              \
    }()
|
||||
259
example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
Normal file
259
example/ck_tile/18_hstu_attention/example_hstu_attention.cpp
Normal file
@@ -0,0 +1,259 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <array>
#include <cassert>
#include <cstring>
#include <functional>
#include <iostream>
#include <numeric>
#include <ostream>
#include <random>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
|
||||
|
||||
#include <ck_tile/host/host_tensor.hpp>
|
||||
#include <ck_tile/host/fill.hpp>
|
||||
#include <ck_tile/host/device_memory.hpp>
|
||||
#include <ck_tile/host/stream_config.hpp>
|
||||
#include <ck_tile/host/arg_parser.hpp>
|
||||
#include <ck_tile/host/hip_check_error.hpp>
|
||||
#include <ck_tile/host/check_err.hpp>
|
||||
#include <ck_tile/host/timer.hpp>
|
||||
|
||||
#include "hstu_attention_setting.hpp"
|
||||
#include "bool_switch.hpp"
|
||||
#include "reference_hstu_attention.hpp"
|
||||
|
||||
// Streams a vector as a bracketed, comma-separated list, e.g. "[1, 2, 3]".
// An empty vector is written as "[]". Returns the stream to allow chaining.
template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
    os << "[";

    bool is_first = true;
    for(const auto& element : v)
    {
        // the separator goes in front of every element except the first one
        if(!is_first)
        {
            os << ", ";
        }
        is_first = false;

        os << element;
    }

    os << "]";
    return os;
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
|
||||
// clang-format off
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16")
|
||||
.insert("jagged", "0", "q/k/v batched sequence is jagged or not")
|
||||
.insert("b", "12", "batch size")
|
||||
.insert("nhead", "4", "number of heads")
|
||||
.insert("hdim_qk", "64", "headdim size of Q/K")
|
||||
.insert("hdim_v", "64", "headdim size of V/O")
|
||||
.insert("seqlen", "400", "seqlen of single or all batches for query and key/value tensor")
|
||||
.insert("targets", "16", "sequence length at the end of query/key token sequence that should be excluded from attention")
|
||||
.insert("causal", "1", "enable causal mask or not")
|
||||
.insert("local_len", "5", "length of the diagonal window for enabling masking, value 0 to disable")
|
||||
.insert("context_len", "6", "sequence length at the begin of the query sequence the should be included for attention")
|
||||
.insert("minfull_len", "6", "sequence length at the end of the query sequence that should be included for attention")
|
||||
.insert("seed", "13579", "seed by the uniform or normal distribution generator")
|
||||
.insert("perf", "0", "weather measure execution time or not");
|
||||
// clang-format on
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// Parses a comma-separated list of integers, e.g. "400,384,512" -> {400, 384, 512}.
//
// An empty string yields an empty vector, so callers can distinguish "no values
// given" from a list of values. Every token between separators must be a number:
// an empty token (as in "1,,2" or "1,2,") makes std::stoi throw
// std::invalid_argument, and an out-of-range value makes it throw std::out_of_range.
static std::vector<int> get_integers_from_string(std::string lengthsStr)
{
    std::vector<int> lengths;

    if(lengthsStr.empty())
        return lengths;

    std::size_t pos = 0;

    while(true)
    {
        const std::size_t comma = lengthsStr.find(',', pos);
        const bool is_last      = (comma == std::string::npos);

        // the last token runs to the end of the string, the others up to the next comma
        const std::string token =
            is_last ? lengthsStr.substr(pos) : lengthsStr.substr(pos, comma - pos);

        lengths.push_back(std::stoi(token));

        if(is_last)
            break;

        pos = comma + 1;
    }

    return lengths;
}
|
||||
|
||||
// threshold for different dtypes
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 2e-3;
|
||||
double atol = 2e-3;
|
||||
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <typename InOutDataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
bool do_validation = static_cast<bool>(arg_parser.get_int("v"));
|
||||
bool is_jagged = static_cast<bool>(arg_parser.get_int("jagged"));
|
||||
int num_batch = arg_parser.get_int("b");
|
||||
int nhead = arg_parser.get_int("nhead");
|
||||
int hdim_qk = arg_parser.get_int("hdim_qk");
|
||||
int hdim_v = arg_parser.get_int("hdim_v");
|
||||
bool use_causal = static_cast<bool>(arg_parser.get_int("causal"));
|
||||
|
||||
int max_attn_len = arg_parser.get_int("local_len");
|
||||
|
||||
bool use_local = (max_attn_len > 0);
|
||||
|
||||
int contextual_seq_len = arg_parser.get_int("context_len");
|
||||
int min_full_seq_len = arg_parser.get_int("minfull_len");
|
||||
|
||||
int seed = arg_parser.get_int("seed");
|
||||
|
||||
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));
|
||||
|
||||
(void)do_validation;
|
||||
(void)measure_perf;
|
||||
|
||||
std::string str_of_targets = arg_parser.get_str("targets");
|
||||
std::vector<int> num_targets = get_integers_from_string(str_of_targets);
|
||||
|
||||
std::string str_of_lengths = arg_parser.get_str("seqlen");
|
||||
std::vector<int> seq_lengths = get_integers_from_string(str_of_lengths);
|
||||
|
||||
std::vector<int> seq_offsets;
|
||||
|
||||
int seqlen = 0; // means total seq lengths for jagged
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
assert(num_batch == seq_lengths.size());
|
||||
|
||||
seq_offsets.push_back(0);
|
||||
for(size_t i = 0; i < seq_lengths.size(); i++)
|
||||
{
|
||||
seqlen += seq_lengths[i];
|
||||
seq_offsets.push_back(seqlen);
|
||||
};
|
||||
|
||||
if(!num_targets.empty())
|
||||
{
|
||||
assert(num_batch == num_targets.size());
|
||||
|
||||
for(size_t i = 0; i < seq_lengths.size(); i++)
|
||||
{
|
||||
assert(seq_lengths[i] - num_targets[i] >= min_full_seq_len);
|
||||
assert(seq_lengths[i] - num_targets[i] >= contextual_seq_len);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
for(size_t i = 0; i < seq_lengths.size(); i++)
|
||||
{
|
||||
assert(seq_lengths[i] >= min_full_seq_len);
|
||||
assert(seq_lengths[i] >= contextual_seq_len);
|
||||
};
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(1 == seq_lengths.size());
|
||||
seqlen = seq_lengths[0];
|
||||
|
||||
if(!num_targets.empty())
|
||||
{
|
||||
assert(1 == num_targets.size());
|
||||
|
||||
assert(seqlen - num_targets[0] >= min_full_seq_len);
|
||||
assert(seqlen - num_targets[0] >= contextual_seq_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(seqlen >= min_full_seq_len);
|
||||
assert(seqlen >= contextual_seq_len);
|
||||
};
|
||||
};
|
||||
|
||||
int batches_for_alloc = is_jagged ? 1 : num_batch;
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> q_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> k_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> v_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_v});
|
||||
ck_tile::HostTensor<InOutDataType> o_host_ref(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_v});
|
||||
|
||||
ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(v_host);
|
||||
|
||||
using GemmAccDataType = typename HSTUAttentionTypeConfig<InOutDataType>::GemmAccDataType;
|
||||
using SMComputeDataType = typename HSTUAttentionTypeConfig<InOutDataType>::SMComputeDataType;
|
||||
|
||||
BOOL_SWITCH_2(use_causal, USE_CAUSAL_, use_local, USE_LOCAL_, [&] {
|
||||
ck_tile::reference_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
SMComputeDataType,
|
||||
USE_CAUSAL_,
|
||||
USE_LOCAL_>::Run(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
num_batch,
|
||||
1.0f,
|
||||
seq_offsets,
|
||||
num_targets,
|
||||
max_attn_len,
|
||||
contextual_seq_len,
|
||||
min_full_seq_len);
|
||||
});
|
||||
return 0;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "Invalid arguments, Failed to parse!" << std::endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run<ck_tile::bf16_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
24
example/ck_tile/18_hstu_attention/hstu_attention_setting.hpp
Normal file
24
example/ck_tile/18_hstu_attention/hstu_attention_setting.hpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
// Type configuration
|
||||
template <typename DataType>
|
||||
struct HSTUAttentionTypeConfig;
|
||||
|
||||
template <>
|
||||
struct HSTUAttentionTypeConfig<ck_tile::fp16_t>
|
||||
{
|
||||
using GemmAccDataType = float;
|
||||
using SMComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HSTUAttentionTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using GemmAccDataType = float;
|
||||
using SMComputeDataType = float;
|
||||
};
|
||||
211
example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
Normal file
211
example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp
Normal file
@@ -0,0 +1,211 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/host/host_tensor.hpp>
|
||||
|
||||
#include "bool_switch.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// clang-format off
|
||||
// Reference implementation of HSTUAttention problem, which does the following from input tensors:
|
||||
// S[num_batch, num_head, seqlen, seqlen] = Q[num_batch, seqlen, num_head, hdim_qk] @ key^T[num_batch, seqlen, num_head, hdim_qk]
|
||||
// P[num_batch, num_head, seqlen, seqlen] = Masking(SiLu(S[num_batch, num_head, seqlen, seqlen]))
|
||||
// O[num_batch, seqlen, num_head, hdim_v] = P[num_batch, num_head, seqlen, seqlen] @ value[num_batch, seqlen, num_head, hdim_v]
|
||||
// The process is very similar to the generic attention, the difference is that SiLu is used rather than Softmax, and hstu masking
|
||||
// is much more complicated than the lower-triangular + diagonal-window based causal mask
|
||||
// clang-format on
|
||||
|
||||
template <typename InOutDataType,
|
||||
typename GemmAccDataType,
|
||||
typename CompDataType,
|
||||
bool use_causal,
|
||||
bool use_local>
|
||||
struct reference_hstu_attention
|
||||
{
|
||||
struct hstu_mask
|
||||
{
|
||||
int max_attn_len;
|
||||
int contextual_seq_len;
|
||||
int min_full_attn_seq_len;
|
||||
int max_uih_len;
|
||||
|
||||
hstu_mask(int max_attn_len_,
|
||||
int contextual_seq_len_,
|
||||
int min_full_attn_seq_len_,
|
||||
int max_uih_len_)
|
||||
{
|
||||
max_attn_len = max_attn_len_;
|
||||
contextual_seq_len = contextual_seq_len_;
|
||||
min_full_attn_seq_len = min_full_attn_seq_len_;
|
||||
max_uih_len = max_uih_len_;
|
||||
};
|
||||
|
||||
bool IsPixelInsideMask(int row, int col)
|
||||
{
|
||||
if(row < contextual_seq_len)
|
||||
return true;
|
||||
|
||||
bool result = false;
|
||||
if constexpr(use_local)
|
||||
{
|
||||
if constexpr(use_causal)
|
||||
result = (row >= col) && (row - col <= max_attn_len);
|
||||
else
|
||||
result = std::abs(row - col) <= max_attn_len;
|
||||
|
||||
if(min_full_attn_seq_len > 0)
|
||||
result = result || (row >= max_uih_len - min_full_attn_seq_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(use_causal)
|
||||
result = (row >= col);
|
||||
};
|
||||
|
||||
return result;
|
||||
};
|
||||
};
|
||||
|
||||
static void Run(const HostTensor<InOutDataType>& q_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& k_batch_seq_nhead_hdim,
|
||||
const HostTensor<InOutDataType>& v_batch_seq_nhead_hdim,
|
||||
HostTensor<InOutDataType>& o_batch_seq_nhead_hdim,
|
||||
int num_batch,
|
||||
float alpha,
|
||||
std::vector<int> seq_offsets,
|
||||
std::vector<int> num_targets, // define masking length at the end of token
|
||||
// sequence to be excluded for attention
|
||||
int max_attn_len, // define the diagonal local window size
|
||||
int contextual_seq_len, // define masking length at the begin of query token
|
||||
// sequence to be included for attention
|
||||
int min_full_attn_seq_len) // define masking length at the end of query token
|
||||
// sequence which is included for full attention
|
||||
{
|
||||
bool is_jagged = !seq_offsets.empty();
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
// check the number of batches
|
||||
assert(seq_offsets.size() == num_batch + 1);
|
||||
assert(q_batch_seq_nhead_hdim.get_lengths()[0] == 1);
|
||||
assert(k_batch_seq_nhead_hdim.get_lengths()[0] == 1);
|
||||
assert(v_batch_seq_nhead_hdim.get_lengths()[0] == 1);
|
||||
assert(o_batch_seq_nhead_hdim.get_lengths()[0] == 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(q_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
|
||||
assert(k_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
|
||||
assert(v_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
|
||||
assert(o_batch_seq_nhead_hdim.get_lengths()[0] == num_batch);
|
||||
};
|
||||
|
||||
// check the sequence length
|
||||
assert(q_batch_seq_nhead_hdim.get_lengths()[1] == k_batch_seq_nhead_hdim.get_lengths()[1]);
|
||||
assert(q_batch_seq_nhead_hdim.get_lengths()[1] == v_batch_seq_nhead_hdim.get_lengths()[1]);
|
||||
assert(q_batch_seq_nhead_hdim.get_lengths()[1] == o_batch_seq_nhead_hdim.get_lengths()[1]);
|
||||
|
||||
// check the number of heads
|
||||
int num_head = q_batch_seq_nhead_hdim.get_lengths()[2];
|
||||
assert(num_head == k_batch_seq_nhead_hdim.get_lengths()[2]);
|
||||
assert(num_head == v_batch_seq_nhead_hdim.get_lengths()[2]);
|
||||
assert(num_head == o_batch_seq_nhead_hdim.get_lengths()[2]);
|
||||
|
||||
// check the hdim
|
||||
int hdim_qk = q_batch_seq_nhead_hdim.get_lengths()[3];
|
||||
int hdim_v = v_batch_seq_nhead_hdim.get_lengths()[3];
|
||||
assert(hdim_qk == k_batch_seq_nhead_hdim.get_lengths()[3]);
|
||||
assert(hdim_v == o_batch_seq_nhead_hdim.get_lengths()[3]);
|
||||
|
||||
auto silu = [](CompDataType x) {
|
||||
auto one = ck_tile::type_convert<CompDataType>(1.0f);
|
||||
|
||||
auto sigmod_val = one / (one + std::exp(-x));
|
||||
|
||||
return sigmod_val * x;
|
||||
};
|
||||
|
||||
bool has_target = !num_targets.empty();
|
||||
|
||||
if(has_target)
|
||||
assert(num_targets.size() == num_batch);
|
||||
|
||||
auto f = [&](auto i_batch, auto i_head) {
|
||||
int seqlen = is_jagged ? (seq_offsets[i_batch + 1] - seq_offsets[i_batch])
|
||||
: q_batch_seq_nhead_hdim.get_lengths()[1];
|
||||
|
||||
int max_uih_len = seqlen;
|
||||
|
||||
if(contextual_seq_len > 0)
|
||||
max_uih_len -= contextual_seq_len - 1;
|
||||
|
||||
if(has_target)
|
||||
max_uih_len -= num_targets[i_batch];
|
||||
|
||||
hstu_mask mask{max_attn_len, contextual_seq_len, min_full_attn_seq_len, max_uih_len};
|
||||
|
||||
// for all rows in the batch
|
||||
for(int sq = 0; sq < max_uih_len; sq++)
|
||||
{
|
||||
std::vector<CompDataType> locals;
|
||||
|
||||
// for all cols in the batch
|
||||
for(int sk = 0; sk < max_uih_len; sk++)
|
||||
{
|
||||
if(mask.IsPixelInsideMask(sq, sk))
|
||||
{
|
||||
GemmAccDataType dot_prod = 0.f;
|
||||
for(int k = 0; k < hdim_qk; k++)
|
||||
{
|
||||
InOutDataType qreg = q_batch_seq_nhead_hdim(i_batch, sq, i_head, k);
|
||||
InOutDataType kreg = k_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(qreg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(kreg);
|
||||
}
|
||||
|
||||
locals.push_back(ck_tile::type_convert<CompDataType>(dot_prod) *
|
||||
ck_tile::type_convert<CompDataType>(alpha));
|
||||
}
|
||||
else
|
||||
locals.push_back(ck_tile::type_convert<CompDataType>(0.0f));
|
||||
};
|
||||
|
||||
// SiLu element-wise
|
||||
for(CompDataType& elem : locals)
|
||||
elem = silu(elem) / ck_tile::type_convert<CompDataType>(seqlen);
|
||||
|
||||
// second Gemm
|
||||
for(int k = 0; k < hdim_v; k++)
|
||||
{
|
||||
GemmAccDataType dot_prod = 0.f;
|
||||
|
||||
for(int sk = 0; sk < max_uih_len; sk++)
|
||||
{
|
||||
InOutDataType preg = ck_tile::type_convert<InOutDataType>(locals[sk]);
|
||||
InOutDataType vreg = v_batch_seq_nhead_hdim(i_batch, sk, i_head, k);
|
||||
|
||||
dot_prod += ck_tile::type_convert<GemmAccDataType>(preg) *
|
||||
ck_tile::type_convert<GemmAccDataType>(vreg);
|
||||
};
|
||||
|
||||
o_batch_seq_nhead_hdim(i_batch, sq, i_head, k) =
|
||||
ck_tile::type_convert<InOutDataType>(dot_prod);
|
||||
};
|
||||
};
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, num_batch, num_head)(std::thread::hardware_concurrency());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -17,4 +17,5 @@ add_subdirectory(14_moe_smoothquant)
|
||||
add_subdirectory(15_fused_moe)
|
||||
add_subdirectory(16_batched_gemm)
|
||||
add_subdirectory(17_grouped_gemm)
|
||||
add_subdirectory(18_hstu_attention)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
|
||||
Reference in New Issue
Block a user