Add hstu attention kernel implementation, instances and interfaces (building succeeded)
@@ -2,10 +2,11 @@ set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention")
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
message("adding example ${EXAMPLE_HSTU_ATTENTION}")
##file(GLOB INSTANCE_SRCS instances/*.cpp)
file(GLOB INSTANCE_SRCS instances/*.cpp)
set(INTERFACES_SRCS hstu_attention_jagged_forward_bf16.cpp hstu_attention_jagged_forward_fp16.cpp hstu_attention_batched_forward_bf16.cpp hstu_attention_batched_forward_fp16.cpp)
add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp)
target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
##target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE hstu_attention_bf16.cpp hstu_attention_fp16.cpp ${INSTANCE_SRCS})
target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS})

set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS)
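Note: since the target is created with EXCLUDE_FROM_ALL, it is not built by "make all/install/check"; it has to be requested explicitly, e.g. via the tile_example_hstu_attention target (cmake --build <build dir> --target tile_example_hstu_attention).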
@@ -22,10 +22,16 @@
#include <ck_tile/host/check_err.hpp>
#include <ck_tile/host/timer.hpp>

#include "hstu_attention_setting.hpp"
#include "bool_switch.hpp"
#include "hstu_attention_fwd_type_config.hpp"
#include "hstu_attention_bool_switch.hpp"
#include "hstu_attention_params.hpp"
#include "reference_hstu_attention.hpp"

extern void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream);
extern void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream);
extern void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream);
extern void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream);

template <typename T>
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
{
@@ -120,25 +126,21 @@ bool run(const ck_tile::ArgParser& arg_parser)
bool do_validation = static_cast<bool>(arg_parser.get_int("v"));
bool is_jagged = static_cast<bool>(arg_parser.get_int("jagged"));
int num_batch = arg_parser.get_int("b");
int nhead = arg_parser.get_int("nhead");
int num_head = arg_parser.get_int("nhead");
int hdim_qk = arg_parser.get_int("hdim_qk");
int hdim_v = arg_parser.get_int("hdim_v");
bool use_causal = static_cast<bool>(arg_parser.get_int("causal"));

int max_attn_len = arg_parser.get_int("local_len");
int window_size = arg_parser.get_int("local_len");

bool use_local = (max_attn_len > 0);
bool use_local = (window_size > 0);

int contextual_seq_len = arg_parser.get_int("context_len");
int min_full_seq_len = arg_parser.get_int("minfull_len");

int seed = arg_parser.get_int("seed");
int contextual_seqlen = arg_parser.get_int("context_len");
int min_full_attn_seqlen = arg_parser.get_int("minfull_len");

int seed = arg_parser.get_int("seed");
bool measure_perf = static_cast<bool>(arg_parser.get_int("perf"));

(void)do_validation;
(void)measure_perf;

std::string str_of_targets = arg_parser.get_str("targets");
std::vector<int> num_targets = get_integers_from_string(str_of_targets);

@@ -147,7 +149,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::vector<int> seq_offsets;
|
||||
|
||||
int seqlen = 0; // means total seq lengths for jagged
|
||||
int seqlen = 0; // means total seq lengths for jagged
|
||||
int max_seqlen = 0;
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
@@ -156,6 +159,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
seq_offsets.push_back(0);
|
||||
for(size_t i = 0; i < seq_lengths.size(); i++)
|
||||
{
|
||||
max_seqlen = max(max_seqlen, seq_lengths[i]);
|
||||
seqlen += seq_lengths[i];
|
||||
seq_offsets.push_back(seqlen);
|
||||
};
|
||||
@@ -166,16 +170,16 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
for(size_t i = 0; i < seq_lengths.size(); i++)
|
||||
{
|
||||
assert(seq_lengths[i] - num_targets[i] >= min_full_seq_len);
|
||||
assert(seq_lengths[i] - num_targets[i] >= contextual_seq_len);
|
||||
assert(seq_lengths[i] - num_targets[i] >= min_full_attn_seqlen);
|
||||
assert(seq_lengths[i] - num_targets[i] >= contextual_seqlen);
|
||||
};
|
||||
}
|
||||
else
|
||||
{
|
||||
for(size_t i = 0; i < seq_lengths.size(); i++)
|
||||
{
|
||||
assert(seq_lengths[i] >= min_full_seq_len);
|
||||
assert(seq_lengths[i] >= contextual_seq_len);
|
||||
assert(seq_lengths[i] >= min_full_attn_seqlen);
|
||||
assert(seq_lengths[i] >= contextual_seqlen);
|
||||
};
|
||||
};
|
||||
}
|
||||
@@ -188,53 +192,212 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
assert(1 == num_targets.size());
|
||||
|
||||
assert(seqlen - num_targets[0] >= min_full_seq_len);
|
||||
assert(seqlen - num_targets[0] >= contextual_seq_len);
|
||||
assert(seqlen - num_targets[0] >= min_full_attn_seqlen);
|
||||
assert(seqlen - num_targets[0] >= contextual_seqlen);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(seqlen >= min_full_seq_len);
|
||||
assert(seqlen >= contextual_seq_len);
|
||||
assert(seqlen >= min_full_attn_seqlen);
|
||||
assert(seqlen >= contextual_seqlen);
|
||||
};
|
||||
};
|
||||
|
||||
int batches_for_alloc = is_jagged ? 1 : num_batch;
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> q_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_qk});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> k_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_qk});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_qk});
|
||||
ck_tile::HostTensor<InOutDataType> v_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_v});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_v});
|
||||
ck_tile::HostTensor<InOutDataType> o_host_ref(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, nhead, hdim_v});
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_v});
|
||||
|
||||
ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(q_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(k_host);
|
||||
ck_tile::FillNormalDistributionIntegerValue<InOutDataType>{-2.f, 2.f, seed}(v_host);
|
||||
|
||||
using GemmAccDataType = typename HSTUAttentionTypeConfig<InOutDataType>::GemmAccDataType;
|
||||
using SMComputeDataType = typename HSTUAttentionTypeConfig<InOutDataType>::SMComputeDataType;
|
||||
ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem o_dev(o_host_ref.get_element_space_size_in_bytes());
|
||||
|
||||
BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] {
|
||||
ck_tile::reference_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
SMComputeDataType,
|
||||
kIsJagged,
|
||||
kUseCausal,
|
||||
kUseLocal>::Run(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
num_batch,
|
||||
1.0f,
|
||||
seq_offsets,
|
||||
num_targets,
|
||||
max_attn_len,
|
||||
contextual_seq_len,
|
||||
min_full_seq_len);
|
||||
});
|
||||
return 0;
|
||||
ck_tile::DeviceMem seq_offsets_dev(seq_offsets.size() * sizeof(int));
|
||||
ck_tile::DeviceMem num_targets_dev(num_targets.size() * sizeof(int));
|
||||
|
||||
q_dev.ToDevice(q_host.data());
|
||||
k_dev.ToDevice(k_host.data());
|
||||
v_dev.ToDevice(v_host.data());
|
||||
|
||||
if(is_jagged)
|
||||
seq_offsets_dev.ToDevice(seq_offsets.data());
|
||||
if(!num_targets.empty())
|
||||
num_targets_dev.ToDevice(num_targets.data());
|
||||
|
||||
HstuAttentionFwdParams params;
|
||||
|
||||
if(is_jagged)
|
||||
{
|
||||
params.is_jagged = true;
|
||||
params.num_batch = num_batch;
|
||||
params.seq_offsets_ptr = seq_offsets_dev.GetDeviceBuffer();
|
||||
params.max_seqlen = max_seqlen;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr;
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = 1.0f / std::sqrt(params.hdim_qk);
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_causal = use_causal;
|
||||
params.window_size = window_size;
|
||||
params.contextual_seqlen = contextual_seqlen;
|
||||
params.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
params.p_drop = 0.0f; // dropout is not supported at present
|
||||
params.philox_seed = 0UL;
|
||||
params.philox_offset = 0UL;
|
||||
}
|
||||
else
|
||||
{
|
||||
params.is_jagged = false;
|
||||
params.num_batch = num_batch;
|
||||
params.seqlen = seqlen;
|
||||
params.q_ptr = q_dev.GetDeviceBuffer();
|
||||
params.k_ptr = k_dev.GetDeviceBuffer();
|
||||
params.v_ptr = v_dev.GetDeviceBuffer();
|
||||
params.bias_ptr = nullptr;
|
||||
params.o_ptr = o_dev.GetDeviceBuffer();
|
||||
params.hdim_qk = hdim_qk;
|
||||
params.hdim_v = hdim_v;
|
||||
params.num_head = num_head;
|
||||
params.scale_s = 1.0f / std::sqrt(params.hdim_qk);
|
||||
params.seq_stride_q = q_host.get_strides()[1];
|
||||
params.seq_stride_k = k_host.get_strides()[1];
|
||||
params.seq_stride_v = v_host.get_strides()[1];
|
||||
params.seq_stride_bias = 0;
|
||||
params.seq_stride_o = o_host_ref.get_strides()[1];
|
||||
params.nhead_stride_q = q_host.get_strides()[2];
|
||||
params.nhead_stride_k = k_host.get_strides()[2];
|
||||
params.nhead_stride_v = v_host.get_strides()[2];
|
||||
params.nhead_stride_bias = 0;
|
||||
params.nhead_stride_o = o_host_ref.get_strides()[2];
|
||||
params.batch_stride_q = q_host.get_strides()[0];
|
||||
params.batch_stride_k = k_host.get_strides()[0];
|
||||
params.batch_stride_v = v_host.get_strides()[0];
|
||||
params.batch_stride_bias = 0;
|
||||
params.batch_stride_o = o_host_ref.get_strides()[0];
|
||||
params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer();
|
||||
params.use_causal = use_causal;
|
||||
params.window_size = window_size;
|
||||
params.contextual_seqlen = contextual_seqlen;
|
||||
params.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
params.p_drop = 0.0f; // dropout is not supported at present
|
||||
params.philox_seed = 0UL;
|
||||
params.philox_offset = 0UL;
|
||||
};
|
||||
|
||||
hipStream_t stream;
|
||||
|
||||
HIP_CHECK_ERROR(hipStreamCreate(&stream));
|
||||
|
||||
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
|
||||
{
|
||||
if(is_jagged)
|
||||
hstu_attention_jagged_forward_fp16(params, stream);
|
||||
else
|
||||
hstu_attention_batched_forward_fp16(params, stream);
|
||||
}
|
||||
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
|
||||
{
|
||||
if(is_jagged)
|
||||
hstu_attention_jagged_forward_bf16(params, stream);
|
||||
else
|
||||
hstu_attention_batched_forward_bf16(params, stream);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("Other data type is not supported at present!");
|
||||
|
||||
bool res = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
using GemmAccDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType;
|
||||
using CompDataType = typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType;
|
||||
|
||||
BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] {
|
||||
ck_tile::reference_hstu_attention<InOutDataType,
|
||||
GemmAccDataType,
|
||||
CompDataType,
|
||||
kIsJagged,
|
||||
kUseCausal,
|
||||
kUseLocal>::Run(q_host,
|
||||
k_host,
|
||||
v_host,
|
||||
o_host_ref,
|
||||
num_batch,
|
||||
1.0f,
|
||||
seq_offsets,
|
||||
num_targets,
|
||||
window_size,
|
||||
contextual_seqlen,
|
||||
min_full_attn_seqlen);
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<InOutDataType> o_host(
|
||||
std::array<ck_tile::index_t, 4>{batches_for_alloc, seqlen, num_head, hdim_v});
|
||||
|
||||
o_dev.FromDevice(o_host.data());
|
||||
|
||||
auto [rtol, atol] = get_elimit<InOutDataType>();
|
||||
|
||||
res = ck_tile::check_err(
|
||||
o_host, o_host_ref, std::string("hstu_attention output error"), atol, rtol);
|
||||
};
|
||||
|
||||
if(measure_perf)
|
||||
{
|
||||
ck_tile::gpu_timer timer{};
|
||||
|
||||
timer.start(stream);
|
||||
for(int i = 0; i < 20; i++)
|
||||
{
|
||||
if constexpr(std::is_same<InOutDataType, ck_tile::fp16_t>::value)
|
||||
{
|
||||
if(is_jagged)
|
||||
hstu_attention_jagged_forward_fp16(params, stream);
|
||||
else
|
||||
hstu_attention_batched_forward_fp16(params, stream);
|
||||
}
|
||||
else if constexpr(std::is_same<InOutDataType, ck_tile::bf16_t>::value)
|
||||
{
|
||||
if(is_jagged)
|
||||
hstu_attention_jagged_forward_bf16(params, stream);
|
||||
else
|
||||
hstu_attention_batched_forward_bf16(params, stream);
|
||||
}
|
||||
}
|
||||
timer.stop(stream);
|
||||
|
||||
auto ms = timer.duration() / 20.f;
|
||||
|
||||
std::cout << "Average execution time of the gather_attention operator is " << ms
|
||||
<< " milli-seconds" << std::endl;
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
|
||||
example/ck_tile/18_hstu_attention/generate_instances.py (new file, 177 lines)
@@ -0,0 +1,177 @@
|
||||
# noqa: C801
|
||||
# Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the BSD-style license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
HSTU_COPYRIGHT_HEADER = """
|
||||
/*
|
||||
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*
|
||||
* The file is automatically generated, don't modify!
|
||||
* See the generator script
|
||||
* `{file}`
|
||||
*/
|
||||
""".format(
|
||||
file=os.path.relpath(os.path.realpath(__file__), start=Path(__file__).parents[4])
|
||||
)
|
||||
|
||||
HSTU_FORWARD_INSTANCE_TEMPLATE_INC = """
|
||||
#include <ck_tile/core/numeric/{dtype_file}.hpp>
|
||||
#include \"hstu_attention_{mode}_forward_dispatch.hpp\"
|
||||
"""
|
||||
|
||||
HSTU_FORWARD_INSTANCE_TEMPLATE = """
|
||||
{extern}template void run_{mode}_forward_causal_local_bias_dropout_dispatch<
|
||||
{dtype},
|
||||
{has_causal},
|
||||
{has_local},
|
||||
{has_bias},
|
||||
{has_dropout},
|
||||
{max_k}>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
"""
|
||||
|
||||
HSTU_FORWARD_INSTANCE_FNAME = (
|
||||
"hstu_attention_{mode}_forward_{dtype_str}_{has_or_no_causal_str}_{has_or_no_local_str}_"
|
||||
"{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp"
|
||||
)
|
||||
|
||||
HSTU_INSTANCE_REF_FNAME = "hstu_attention_{mode}_{function}_{dtype}_instances_ref.hpp"
|
||||
|
||||
BOOL_MAP = {True: "true", False: "false"}
|
||||
|
||||
BOOL_MAP_CAUSAL = {
|
||||
True: "has_causal",
|
||||
False: "no_causal",
|
||||
}
|
||||
|
||||
BOOL_MAP_LOCAL = {
|
||||
True: "has_local",
|
||||
False: "no_local",
|
||||
}
|
||||
|
||||
BOOL_MAP_BIAS = {
|
||||
True: "has_bias",
|
||||
False: "no_bias",
|
||||
}
|
||||
|
||||
BOOL_MAP_DROPOUT = {
|
||||
True: "has_dropout",
|
||||
False: "no_dropout",
|
||||
}
|
||||
|
||||
INT_MAP_MAX_K = {hd: f"maxk_{hd}" for hd in [64, 128, 256]}
|
||||
|
||||
TYPE_CTYPE_MAP = {
|
||||
"fp16": "ck_tile::fp16_t",
|
||||
"bf16": "ck_tile::bf16_t",
|
||||
}
|
||||
|
||||
# dtype tag -> ck_tile numeric header name (fp16_t is declared in half.hpp, bf16_t in bfloat16.hpp)
TYPE_FNAME_MAP = {
    "fp16": "half",
    "bf16": "bfloat16",
}
|
||||
|
||||
MODE_NAME_MAP = {
|
||||
"batched": "Batched",
|
||||
"jagged": "Jagged",
|
||||
}
|
||||
|
||||
def create_forward_instances(instance_dir: Path, headdims: List) -> None:
|
||||
for mode in ["batched", "jagged"]:
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
for has_causal, has_local in zip([True, False], [True, False]):
|
||||
for has_bias in [True, False]:
|
||||
for has_dropout in [True, False]:
|
||||
for max_k in headdims:
|
||||
fname = HSTU_FORWARD_INSTANCE_FNAME.format(
|
||||
mode=mode,
|
||||
dtype_str=dtype,
|
||||
has_or_no_causal_str=BOOL_MAP_CAUSAL[has_causal],
|
||||
has_or_no_local_str=BOOL_MAP_LOCAL[has_local],
|
||||
has_or_no_bias_str=BOOL_MAP_BIAS[has_bias],
|
||||
has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout],
|
||||
max_k_str=INT_MAP_MAX_K[max_k],
|
||||
)
|
||||
forward_instance_inc = (
|
||||
HSTU_FORWARD_INSTANCE_TEMPLATE_INC.format(
|
||||
mode=mode,
|
||||
dtype_file=TYPE_FNAME_MAP[dtype],
|
||||
)
|
||||
)
|
||||
forward_instance = HSTU_FORWARD_INSTANCE_TEMPLATE.format(
|
||||
extern="",
|
||||
mode=mode,
|
||||
dtype=TYPE_CTYPE_MAP[dtype],
|
||||
has_causal=BOOL_MAP[has_causal],
|
||||
has_local=BOOL_MAP[has_local],
|
||||
has_bias=BOOL_MAP[has_bias],
|
||||
has_dropout=BOOL_MAP[has_dropout],
|
||||
max_k=max_k,
|
||||
cap_mode=MODE_NAME_MAP[mode],
|
||||
)
|
||||
(instance_dir / fname).write_text(
|
||||
HSTU_COPYRIGHT_HEADER
|
||||
+ forward_instance_inc
|
||||
+ forward_instance
|
||||
)
|
||||
|
||||
|
||||
def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None:
|
||||
for mode in ["batched", "jagged"]:
|
||||
for dtype in ["fp16", "bf16"]:
|
||||
ref_fname = HSTU_INSTANCE_REF_FNAME.format(
|
||||
mode=mode,
|
||||
function="forward",
|
||||
dtype=dtype,
|
||||
)
|
||||
ref_fname_path = instance_dir / ref_fname
|
||||
forward_instance_inc = HSTU_FORWARD_INSTANCE_TEMPLATE_INC.format(
|
||||
mode=mode,
|
||||
dtype_file=TYPE_FNAME_MAP[dtype],
|
||||
)
|
||||
with open(ref_fname_path, "a") as file:
|
||||
file.write(HSTU_COPYRIGHT_HEADER)
|
||||
file.write(forward_instance_inc)
|
||||
for max_k in headdims:
|
||||
for has_bias in [True, False]:
|
||||
for has_dropout in [True, False]:
|
||||
for has_causal, has_local in zip([True, False], [True, False]):
|
||||
forward_instance = (
|
||||
HSTU_FORWARD_INSTANCE_TEMPLATE.format(
|
||||
extern="extern ",
|
||||
mode=mode,
|
||||
dtype=TYPE_CTYPE_MAP[dtype],
|
||||
has_causal=BOOL_MAP[has_causal],
|
||||
has_local=BOOL_MAP[has_local],
|
||||
has_bias=BOOL_MAP[has_bias],
|
||||
has_dropout=BOOL_MAP[has_dropout],
|
||||
max_k=max_k,
|
||||
cap_mode=MODE_NAME_MAP[mode],
|
||||
)
|
||||
)
|
||||
file.write(forward_instance)
|
||||
|
||||
if __name__ == "__main__":
|
||||
headdims_fwd = [64, 128, 256]
|
||||
|
||||
this_dir = os.path.dirname(__file__)
|
||||
output_dir = Path(this_dir) / "instances"
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# remove existing files in the directory
|
||||
files = os.listdir(output_dir)
|
||||
for ff in files:
|
||||
file_path = os.path.join(output_dir, ff)
|
||||
os.remove(file_path)
|
||||
|
||||
create_forward_instances(output_dir, headdims_fwd)
|
||||
create_forward_instances_ref(output_dir, headdims_fwd)
|
||||
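For illustration only (a sketch reconstructed from the template strings above, not a file copied from the commit): with mode=batched, dtype=fp16, has_causal=True, has_local=True, has_bias=False, has_dropout=False and max_k=128, the generator would emit a file named roughly hstu_attention_batched_forward_fp16_has_causal_has_local_no_bias_no_dropout_maxk_128.cpp whose body, apart from the copyright header, is:

// Sketch of generated output; produced by filling HSTU_FORWARD_INSTANCE_TEMPLATE_INC
// and HSTU_FORWARD_INSTANCE_TEMPLATE with the values listed above.
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"

template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t,
    true,
    true,
    false,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

The matching _instances_ref.hpp header carries the same instantiation with an "extern " prefix, so translation units that include it link against the pre-built instance instead of re-instantiating the dispatch template themselves.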
example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_bf16.cpp (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
#include "instances/hstu_attention_batched_forward_bf16_instances_ref.hpp"
|
||||
|
||||
void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const bool has_dropout = (param.p_drop > 0.0f);
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
if(param.window_size > 0)
|
||||
{
|
||||
run_batched_forward_causal_local_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
true,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_batched_forward_causal_local_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
false,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
};
|
||||
});
|
||||
});
|
||||
};
|
||||
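The BOOL_SWITCH_3 and HDIM_SWITCH helpers used above live in hstu_attention_bool_switch.hpp and hstu_attention_hdim_switch.hpp, which are not part of this excerpt. As a rough sketch of the usual pattern (a runtime flag lifted into a constexpr template parameter, as in the ck_tile fmha example; the actual macros in this commit may differ):

// Hypothetical single-flag helper; a BOOL_SWITCH_3 would simply nest three of these.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if(COND)                                      \
        {                                             \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        }                                             \
        else                                          \
        {                                             \
            constexpr static bool CONST_NAME = false; \
            return __VA_ARGS__();                     \
        }                                             \
    }()

Because the trailing lambda is pasted textually into each branch, the constexpr flag is visible inside it and can be passed on as a template argument, which is how kHasBias, kHasDropout and kUseCausal reach the dispatch function above.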
example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp (new file, 155 lines)
@@ -0,0 +1,155 @@
|
||||
/*
|
||||
* Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core/numeric/integer.hpp>
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/host/stream_config.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
#include "hstu_attention_fwd_setting.hpp"
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_block_masking.hpp"
|
||||
#include "hstu_attention_pipeline_problem.hpp"
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_kernel.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
struct batched_forward_causal_local_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
|
||||
using HstuMask = ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
InOutDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
|
||||
false, // kIsJagged
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
HstuMask,
|
||||
HstuAttentionShape,
|
||||
HstuTraits>;
|
||||
|
||||
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
const bool pad_seqlen_k = !(param.seqlen % HstuAttentionShape::kN0 == 0);
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0);
|
||||
|
||||
// no need to check seqlen_q since it is not used as fastest dim,
|
||||
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
constexpr bool kPadSeqLenQ = false;
|
||||
|
||||
BOOL_SWITCH_3(
|
||||
pad_seqlen_k,
|
||||
kPadSeqLenK,
|
||||
pad_headdim_qk,
|
||||
kPadHeadDimQK,
|
||||
pad_headdim_v,
|
||||
kPadHeadDimV,
|
||||
[&] {
|
||||
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
|
||||
kPadSeqLenK,
|
||||
kPadHeadDimQK,
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
|
||||
using HstuEpilogue = ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
kPadSeqLenQ,
|
||||
kPadHeadDimV>>;
|
||||
|
||||
using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>;
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
});
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.q_ptr,
|
||||
param.k_ptr,
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
param.seqlen,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
param.scale_s,
|
||||
param.seq_stride_q,
|
||||
param.seq_stride_k,
|
||||
param.seq_stride_v,
|
||||
param.seq_stride_bias,
|
||||
param.seq_stride_o,
|
||||
param.nhead_stride_q,
|
||||
param.nhead_stride_k,
|
||||
param.nhead_stride_v,
|
||||
param.nhead_stride_bias,
|
||||
param.nhead_stride_o,
|
||||
param.batch_stride_q,
|
||||
param.batch_stride_k,
|
||||
param.batch_stride_v,
|
||||
param.batch_stride_bias,
|
||||
param.batch_stride_o,
|
||||
param.num_targets_ptr,
|
||||
param.window_size,
|
||||
param.contextual_seqlen,
|
||||
param.min_full_attn_seqlen,
|
||||
param.p_drop,
|
||||
param.philox_seed,
|
||||
param.philox_offset);
|
||||
}();
|
||||
|
||||
dim3 kGridSize =
|
||||
HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(ck_tile::stream_config{stream, false},
|
||||
ck_tile::make_kernel<kBlockSize.x, kBlockPerCu>(
|
||||
HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
void run_batched_forward_causal_local_bias_dropout_dispatch(HstuAttentionFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
batched_forward_causal_local_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kUseLocal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>::Run(param, stream);
|
||||
};
|
||||
example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_fp16.cpp (new file, 43 lines)
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_batched_forward_dispatch.hpp"
|
||||
|
||||
#include "instances/hstu_attention_batched_forward_fp16_instances_ref.hpp"
|
||||
|
||||
void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const bool has_dropout = (param.p_drop > 0.0f);
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
if(param.window_size > 0)
|
||||
{
|
||||
run_batched_forward_causal_local_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
true,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_batched_forward_causal_local_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
false,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
};
|
||||
});
|
||||
});
|
||||
};
|
||||
example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp (new file, 763 lines)
@@ -0,0 +1,763 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
|
||||
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
|
||||
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
|
||||
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
|
||||
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
|
||||
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename HstuAttentionPipeline_, typename EpiloguePipeline_>
|
||||
struct HstuAttentionFwdKernel
|
||||
{
|
||||
using HstuAttentionPipeline = ck_tile::remove_cvref_t<HstuAttentionPipeline_>;
|
||||
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
|
||||
static constexpr ck_tile::index_t kBlockSize = HstuAttentionPipeline::kBlockSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = HstuAttentionPipeline::kBlockPerCu;
|
||||
static_assert(kBlockPerCu > 0);
|
||||
static constexpr ck_tile::index_t kBlockPerCuInput =
|
||||
HstuAttentionPipeline::Problem::kBlockPerCu;
|
||||
|
||||
using QKVDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::QKVDataType>;
|
||||
using BiasDataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::BiasDataType>;
|
||||
using ODataType = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::ODataType>;
|
||||
|
||||
using VLayout = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::VLayout>;
|
||||
|
||||
static constexpr bool kIsJagged = HstuAttentionPipeline::kIsJagged;
|
||||
static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV;
|
||||
static constexpr auto kHasBias = HstuAttentionPipeline::kHasBias;
|
||||
static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout;
|
||||
using HstuMask = ck_tile::remove_cvref_t<typename HstuAttentionPipeline::HstuMask>;
|
||||
static constexpr bool kHasMask = HstuMask::IsMasking;
|
||||
|
||||
template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
|
||||
// arg
|
||||
struct HstuAttentionFwdEmptyKargs
|
||||
{
|
||||
};
|
||||
|
||||
// kargs use aggregate initializer, so no constructor will be provided
|
||||
// use inheritance to minimize karg size
|
||||
// user need to use MakeKargs() function to create kargs.
|
||||
struct HstuAttentionFwdCommonKargs
|
||||
{
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t seqlen;
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s;
|
||||
|
||||
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
const int32_t* num_targets_ptr;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdCommonBiasKargs
|
||||
{
|
||||
const void* bias_ptr = nullptr;
|
||||
ck_tile::index_t seq_stride_bias = 0;
|
||||
ck_tile::index_t nhead_stride_bias = 0;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdBatchModeBiasKargs : HstuAttentionFwdCommonBiasKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_bias = 0;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdMaskKargs
|
||||
{
|
||||
ck_tile::index_t window_size;
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdDropoutSeedOffset
|
||||
{
|
||||
uint64_t drop_seed;
|
||||
uint64_t drop_offset;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdCommonDropoutKargs : HstuAttentionFwdDropoutSeedOffset
|
||||
{
|
||||
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
|
||||
{
|
||||
float p_undrop = 1.0 - p_drop;
|
||||
p_undrop_in_uint8_t =
|
||||
uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
|
||||
rp_undrop = 1.0 / p_undrop;
|
||||
|
||||
this->drop_seed = seed;
|
||||
this->drop_offset = offset;
|
||||
}
|
||||
|
||||
float rp_undrop = 1;
|
||||
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdBatchModeKargs
|
||||
: HstuAttentionFwdCommonKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdBatchModeBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasMask, HstuAttentionFwdMaskKargs, HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
};
|
||||
|
||||
struct HstuAttentionFwdJaggModeKargs
|
||||
: HstuAttentionFwdCommonKargs,
|
||||
std::conditional_t<kHasBias,
|
||||
HstuAttentionFwdCommonBiasKargs,
|
||||
HstuAttentionFwdEmptyKargs<0>>,
|
||||
std::conditional_t<kHasMask, HstuAttentionFwdMaskKargs, HstuAttentionFwdEmptyKargs<1>>,
|
||||
std::conditional_t<kHasDropout,
|
||||
HstuAttentionFwdCommonDropoutKargs,
|
||||
HstuAttentionFwdEmptyKargs<2>>
|
||||
{
|
||||
const int32_t* seq_offsets_ptr;
|
||||
};
|
||||
|
||||
using Kargs = std::
|
||||
conditional_t<kIsJagged, HstuAttentionFwdJaggModeKargs, HstuAttentionFwdBatchModeKargs>;
|
||||
|
||||
template <bool Cond = !kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
seqlen,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr)}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_o};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.seq_stride_bias = seq_stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
kargs.batch_stride_bias = batch_stride_bias;
|
||||
}
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.contextual_seqlen = contextual_seqlen;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto seed = std::get<0>(drop_seed_offset);
|
||||
auto offset = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = !kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t seqlen,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_q,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_bias,
|
||||
ck_tile::index_t batch_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
o_ptr,
|
||||
seqlen,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_bias,
|
||||
seq_stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_o,
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v,
|
||||
batch_stride_bias,
|
||||
batch_stride_o,
|
||||
num_targets_ptr,
|
||||
window_size,
|
||||
contextual_seqlen,
|
||||
min_full_attn_seqlen,
|
||||
p_drop,
|
||||
std::make_pair(philox_seed, philox_offset));
|
||||
}
|
||||
|
||||
template <bool Cond = kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargsImpl(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
const void* seq_offsets_ptr,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
const std::pair<uint64_t, uint64_t>& drop_seed_offset)
|
||||
{
|
||||
Kargs kargs{{q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
o_ptr,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_o,
|
||||
reinterpret_cast<const int32_t*>(num_targets_ptr)}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
{}, // placeholder for mask
|
||||
{}, // placeholder for dropout
|
||||
reinterpret_cast<const int32_t*>(seq_offsets_ptr)};
|
||||
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
kargs.bias_ptr = bias_ptr;
|
||||
kargs.seq_stride_bias = seq_stride_bias;
|
||||
kargs.nhead_stride_bias = nhead_stride_bias;
|
||||
}
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
kargs.window_size = window_size;
|
||||
kargs.contextual_seqlen = contextual_seqlen;
|
||||
kargs.min_full_attn_seqlen = min_full_attn_seqlen;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto seed = std::get<0>(drop_seed_offset);
|
||||
auto offset = std::get<1>(drop_seed_offset);
|
||||
kargs.init_dropout(p_drop, seed, offset);
|
||||
}
|
||||
|
||||
return kargs;
|
||||
}
|
||||
|
||||
template <bool Cond = kIsJagged>
|
||||
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
|
||||
MakeKargs(const void* q_ptr,
|
||||
const void* k_ptr,
|
||||
const void* v_ptr,
|
||||
const void* bias_ptr,
|
||||
void* o_ptr,
|
||||
const void* seq_offsets_ptr,
|
||||
ck_tile::index_t hdim_qk,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head,
|
||||
float scale_s,
|
||||
ck_tile::index_t seq_stride_q,
|
||||
ck_tile::index_t seq_stride_k,
|
||||
ck_tile::index_t seq_stride_v,
|
||||
ck_tile::index_t seq_stride_bias,
|
||||
ck_tile::index_t seq_stride_o,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
const void* num_targets_ptr,
|
||||
ck_tile::index_t window_size,
|
||||
ck_tile::index_t contextual_seqlen,
|
||||
ck_tile::index_t min_full_attn_seqlen,
|
||||
float p_drop,
|
||||
uint64_t philox_seed,
|
||||
uint64_t philox_offset)
|
||||
{
|
||||
return MakeKargsImpl(q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
o_ptr,
|
||||
seq_offsets_ptr,
|
||||
hdim_qk,
|
||||
hdim_v,
|
||||
num_head,
|
||||
scale_s,
|
||||
seq_stride_q,
|
||||
seq_stride_k,
|
||||
seq_stride_v,
|
||||
seq_stride_bias,
|
||||
seq_stride_o,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
nhead_stride_bias,
|
||||
nhead_stride_o,
|
||||
num_targets_ptr,
|
||||
window_size,
|
||||
contextual_seqlen,
|
||||
min_full_attn_seqlen,
|
||||
p_drop,
|
||||
std::make_pair(philox_seed, philox_offset));
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 =
|
||||
ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
const index_t i_nhead = blockIdx.y;
|
||||
const index_t i_batch = blockIdx.z;
|
||||
|
||||
const auto f = [](index_t dividend, index_t divisor) {
|
||||
index_t quotient = dividend / divisor;
|
||||
index_t modulus = dividend - quotient * divisor;
|
||||
return ck_tile::make_tuple(quotient, modulus);
|
||||
};
|
||||
|
||||
const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
|
||||
|
||||
return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return ck_tile::max(HstuAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
|
||||
// divide problem
|
||||
const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
|
||||
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1);
|
||||
|
||||
long_index_t batch_offset_q = 0;
|
||||
long_index_t batch_offset_k = 0;
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
if constexpr(kIsJagged)
|
||||
{
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seq_offsets_ptr[i_batch];
|
||||
const long_index_t key_start = query_start;
|
||||
|
||||
batch_offset_q = query_start * kargs.seq_stride_q;
|
||||
batch_offset_k = key_start * kargs.seq_stride_k;
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.seq_stride_v;
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_v = key_start;
|
||||
}
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
batch_offset_bias = query_start * kargs.seq_stride_bias;
|
||||
}
|
||||
batch_offset_o = query_start * kargs.seq_stride_o;
|
||||
|
||||
kargs.seqlen = kargs.seq_offsets_ptr[i_batch + 1] - kargs.seq_offsets_ptr[i_batch];
|
||||
|
||||
// # of required blocks is different in each group, terminate unnecessary blocks
|
||||
// earlier
|
||||
if(kargs.seqlen <= i_m0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
|
||||
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
|
||||
}
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
}
|
||||
|
||||
int max_uih_len = kargs.seqlen;
|
||||
|
||||
if constexpr(kHasMask)
|
||||
{
|
||||
if(kargs.contextual_seqlen > 0)
|
||||
max_uih_len -= kargs.contextual_seqlen - 1;
|
||||
};
|
||||
|
||||
if(kargs.num_targets_ptr != nullptr)
|
||||
{
|
||||
if constexpr(kIsJagged)
|
||||
max_uih_len -= kargs.num_targets_ptr[i_batch];
|
||||
else
|
||||
max_uih_len -= kargs.num_targets_ptr[0];
|
||||
};
|
||||
|
||||
HstuMask mask = [&]() {
|
||||
if constexpr(kHasMask)
|
||||
return HstuMask{kargs.window_size,
|
||||
kargs.contextual_seqlen,
|
||||
kargs.min_full_attn_seqlen,
|
||||
max_uih_len};
|
||||
else
|
||||
return HstuMask{0, 0, 0, 0};
|
||||
}();
|
||||
|
||||
// for simplicity, the batch stride is folded directly into the pointer
|
||||
const QKVDataType* q_ptr = reinterpret_cast<const QKVDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const QKVDataType* k_ptr = reinterpret_cast<const QKVDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const QKVDataType* v_ptr = reinterpret_cast<const QKVDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
|
||||
// Q/K/V DRAM and DRAM window
|
||||
const auto q_dram = [&]() {
|
||||
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_qk),
|
||||
make_tuple(kargs.seq_stride_q, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentQ>{},
|
||||
number<1>{});
|
||||
if constexpr(HstuAttentionPipeline::kQLoadOnce)
|
||||
{
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQK>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return pad_tensor_view(q_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kK0>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQK>{});
|
||||
}
|
||||
}();
|
||||
const auto k_dram = [&]() {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
k_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_qk),
|
||||
make_tuple(kargs.seq_stride_k, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentK>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(k_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kN0>{},
|
||||
number<HstuAttentionPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQK>{});
|
||||
}();
|
||||
const auto v_dram = [&]() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_v),
|
||||
make_tuple(kargs.seq_stride_v, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
const auto v_dram_transposed =
|
||||
transform_tensor_view(v_dram_naive,
|
||||
make_tuple(make_pass_through_transform(kargs.hdim_v),
|
||||
make_pass_through_transform(kargs.seqlen)),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return pad_tensor_view(v_dram_transposed,
|
||||
make_tuple(number<HstuAttentionPipeline::kN1>{},
|
||||
number<HstuAttentionPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, false>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
v_ptr,
|
||||
make_tuple(kargs.hdim_v, kargs.seqlen),
|
||||
make_tuple(kargs.seq_stride_v, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentV>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(v_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kN1>{},
|
||||
number<HstuAttentionPipeline::kK1>{}),
|
||||
sequence<kPadHeadDimV, false>{});
|
||||
}
|
||||
}();
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
[&]() {
|
||||
if constexpr(HstuAttentionPipeline::kQLoadOnce)
|
||||
return make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kSubQKHeaddim>{});
|
||||
else
|
||||
return make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kK0>{});
|
||||
}(),
|
||||
{i_m0, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kN0>{}, number<HstuAttentionPipeline::kK0>{}),
|
||||
{0, 0});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kN1>{}, number<HstuAttentionPipeline::kK1>{}),
|
||||
{i_n1, 0});
|
||||
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
|
||||
/// following copy capture of the 'i_nhead' if in C++20
|
||||
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
|
||||
constexpr auto bias_dram_window_lengths = make_tuple(
|
||||
number<HstuAttentionPipeline::kM0>{}, number<HstuAttentionPipeline::kN0>{});
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
const BiasDataType* bias_ptr =
|
||||
reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
|
||||
batch_offset_bias;
|
||||
|
||||
const auto bias_dram = [&]() {
|
||||
const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
bias_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.seqlen),
|
||||
make_tuple(kargs.seq_stride_bias, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentBias>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(bias_dram_naive,
|
||||
bias_dram_window_lengths,
|
||||
sequence<kPadSeqLenQ, kPadSeqLenK>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_null_tile_window(bias_dram_window_lengths);
|
||||
}
|
||||
}();
|
||||
|
||||
auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
return BlockDropout{i_batch_,
|
||||
i_nhead_,
|
||||
kargs.num_head,
|
||||
kargs.drop_seed,
|
||||
kargs.drop_offset,
|
||||
kargs.rp_undrop,
|
||||
kargs.p_undrop_in_uint8_t,
|
||||
false};
|
||||
}
|
||||
else
|
||||
{
|
||||
return NullBlockDropout{};
|
||||
};
|
||||
}();
|
||||
|
||||
auto o_acc_tile = [&]() {
|
||||
return HstuAttentionPipeline{}(q_dram_window,
|
||||
k_dram_window,
|
||||
v_dram_window,
|
||||
bias_dram_window,
|
||||
mask,
|
||||
kargs.scale_s,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}();
|
||||
|
||||
// O DRAM and O DRAM window
|
||||
auto o_dram = [&]() {
|
||||
const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_ptr,
|
||||
make_tuple(kargs.seqlen, kargs.hdim_v),
|
||||
make_tuple(kargs.seq_stride_o, 1),
|
||||
number<HstuAttentionPipeline::kAlignmentO>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(o_dram_naive,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{},
|
||||
number<HstuAttentionPipeline::kN1>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimV>{});
|
||||
}();
|
||||
|
||||
auto o_dram_window = make_tile_window(
|
||||
o_dram,
|
||||
make_tuple(number<HstuAttentionPipeline::kM0>{}, number<HstuAttentionPipeline::kN1>{}),
|
||||
{i_m0, i_n1});
|
||||
|
||||
EpiloguePipeline{}(o_dram_window, o_acc_tile);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp (new file, 548 lines)
@@ -0,0 +1,548 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
|
||||
|
||||
#include "hstu_attention_fwd_pipeline_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = HstuAttentionFwdPipelineQRKSVSDefaultPolicy>
|
||||
struct HstuAttentionFwdPipelineQRKSVS
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using QKVDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using GemmAccDataType = remove_cvref_t<typename Problem::GemmAccDataType>;
|
||||
using CompDataType = remove_cvref_t<typename Problem::CompDataType>;
|
||||
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
|
||||
using PDataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::InOutDataType>;
|
||||
using HstuMask = remove_cvref_t<typename Problem::HstuMask>;
|
||||
|
||||
using HstuAttentionTileShape = remove_cvref_t<typename Problem::HstuAttentionTileShape>;
|
||||
using VLayout = remove_cvref_t<typename HstuAttentionTileShape::VLayout>;
|
||||
static constexpr bool kQLoadOnce = true;
|
||||
static_assert(kQLoadOnce == Policy::QLoadOnce);
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = HstuAttentionTileShape::kM0;
|
||||
static constexpr index_t kN0 = HstuAttentionTileShape::kN0;
|
||||
static constexpr index_t kK0 = HstuAttentionTileShape::kK0;
|
||||
static constexpr index_t kN1 = HstuAttentionTileShape::kN1;
|
||||
static constexpr index_t kK1 = HstuAttentionTileShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = HstuAttentionTileShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = HstuAttentionTileShape::kSubQKHeaddim;
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
static constexpr bool kIsJagged = Problem::kIsJagged;
|
||||
static constexpr auto kHasBias = Problem::kHasBias;
|
||||
static constexpr bool kHasDropout = Problem::kHasDropout;
|
||||
|
||||
static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK;
|
||||
static constexpr bool kPadHeadDimV =
|
||||
(kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV;
|
||||
|
||||
// last dimension vector length used to create tensor view (and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should be able to overwrite this
|
||||
static constexpr index_t kAlignmentQ =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ<Problem>();
|
||||
static constexpr index_t kAlignmentK =
|
||||
kPadHeadDimQK ? 1 : Policy::template GetAlignmentK<Problem>();
|
||||
static constexpr index_t kAlignmentV = []() {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
return Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
else
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
static constexpr index_t kBlockPerCu = []() {
|
||||
if constexpr(Problem::Traits::kBlockPerCu != -1)
|
||||
return Problem::Traits::kBlockPerCu;
|
||||
else
|
||||
{
|
||||
if constexpr(kQKHeaddim == 32)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 64)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128)
|
||||
{
|
||||
if constexpr(kHasBias)
|
||||
return 1;
|
||||
else
|
||||
return 2;
|
||||
}
|
||||
else if constexpr(kQKHeaddim == 256)
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 1;
|
||||
};
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr const char* name = "qr_hstu";
|
||||
|
||||
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp,
|
||||
typename QElementFunction,
|
||||
typename KElementFunction,
|
||||
typename VElementFunction,
|
||||
typename BiasElementFunction,
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile
|
||||
const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const VElementFunction& v_element_func,
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
const BiasElementFunction& bias_element_func,
|
||||
const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
HstuMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
ignore = q_element_func;
|
||||
ignore = k_element_func;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QKVDataType,
|
||||
remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<QKVDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
|
||||
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = number<0>{};
|
||||
|
||||
constexpr index_t k0_loops = kQKHeaddim / kK0;
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
static_assert(2 <= k0_loops);
|
||||
static_assert(2 <= k1_loops);
|
||||
|
||||
constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers<Problem>();
|
||||
constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers<Problem>();
|
||||
constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV<Problem>();
|
||||
|
||||
static_assert(NumKLdsBuffers >= 2);
|
||||
|
||||
auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
q_dram_block_window_tmp.get_window_lengths(),
|
||||
q_dram_block_window_tmp.get_window_origin(),
|
||||
Policy::template MakeQRegTileDistribution<Problem>());
|
||||
|
||||
const auto q_origin = q_dram_window.get_window_origin();
|
||||
const auto [seqlen_k_start, seqlen_k_end] =
|
||||
mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
|
||||
|
||||
auto k_dram_block_window =
|
||||
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
k_dram_block_window_tmp.get_window_lengths(),
|
||||
{seqlen_k_start, 0});
|
||||
|
||||
auto k_dram_window =
|
||||
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
|
||||
k_dram_block_window.get_window_lengths(),
|
||||
k_dram_block_window.get_window_origin(),
|
||||
Policy::template MakeKDramTileDistribution<Problem>());
|
||||
|
||||
auto k_tile = load_tile(k_dram_window);
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
auto q_tile = load_tile(q_dram_window);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// K tile in LDS
|
||||
QKVDataType* k_lds_ptr = static_cast<QKVDataType*>(smem_ptr);
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsBlockDescriptor<Problem>());
|
||||
auto k_lds_window = make_tile_window(
|
||||
k_lds, Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using k_lds_window_type =
|
||||
decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence<kN0, kK0>{}));
|
||||
|
||||
statically_indexed_array<k_lds_window_type, NumKLdsBuffers> k_lds_windows;
|
||||
|
||||
static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
k_lds_windows[i_buf] = get_slice_tile(
|
||||
k_lds_window, sequence<i_buf * kN0, 0>{}, sequence<(i_buf + 1) * kN0, kK0>{});
|
||||
});
|
||||
|
||||
auto v_dram_window =
|
||||
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
v_dram_block_window_tmp.get_window_lengths(),
|
||||
{0, seqlen_k_start}, // TODO: hdim split?
|
||||
Policy::template MakeVDramTileDistribution<Problem>());
|
||||
// V tile in LDS
|
||||
auto v_lds = make_tensor_view<address_space_enum::lds>(
|
||||
reinterpret_cast<QKVDataType*>(static_cast<char*>(smem_ptr) +
|
||||
Policy::template GetExclusiveKLdsBytes<Problem>()),
|
||||
Policy::template MakeVLdsBlockDescriptor<Problem>());
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
|
||||
|
||||
using v_tile_type = decltype(load_tile(v_dram_window));
|
||||
|
||||
statically_indexed_array<v_tile_type, NumPrefetchV> v_tiles;
|
||||
|
||||
using v_lds_window_type =
|
||||
decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence<kN1, kK1>{}));
|
||||
|
||||
statically_indexed_array<v_lds_window_type, NumVLdsBuffers> v_lds_windows;
|
||||
|
||||
static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) {
|
||||
v_lds_windows[i_buf] = get_slice_tile(
|
||||
v_lds_window, sequence<i_buf * kN1, 0>{}, sequence<(i_buf + 1) * kN1, kK1>{});
|
||||
});
|
||||
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
|
||||
|
||||
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
|
||||
auto s_acc = SaccBlockTileType{};
|
||||
|
||||
// SiLU activation applied elementwise to the scores (HSTU uses SiLU in place of softmax)
const auto f_silu = [](CompDataType x) {
    auto one = ck_tile::type_convert<CompDataType>(1.0f);

    auto sigmoid_val = one / (one + exp(-x));

    return sigmoid_val * x;
|
||||
};
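// note: silu(x) = x * sigmoid(x) = x / (1 + exp(-x)), applied elementwise to the scaled scores;
// unlike softmax attention, no row-wise max/sum reduction is required in this pipeline.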
|
||||
|
||||
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
|
||||
|
||||
// init Oacc
|
||||
auto o_acc = OaccBlockTileType{};
|
||||
|
||||
clear_tile(o_acc);
|
||||
|
||||
const auto num_loops = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
|
||||
// check early exit if no work to do
|
||||
if constexpr(HstuMask::IsMasking || kPadSeqLenK)
|
||||
{
|
||||
if(num_loops <= 0)
|
||||
{
|
||||
return o_acc;
|
||||
}
|
||||
}
|
||||
|
||||
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
|
||||
auto bias_dram_window =
|
||||
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
bias_dram_block_window_tmp.get_window_lengths(),
|
||||
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
|
||||
Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
|
||||
|
||||
auto null_randval_window = [&]() {
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
const auto null_randval_dram = [&]() {
|
||||
const auto null_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<uint8_t*>(nullptr),
|
||||
make_tuple(1, 1),
|
||||
make_tuple(1, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(null_dram_naive,
|
||||
make_tuple(number<1>{}, number<1>{}),
|
||||
sequence<true, true>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0});
|
||||
}
|
||||
else
|
||||
return make_null_tile_window(make_tuple(number<1>{}, number<1>{}));
|
||||
}();
|
||||
|
||||
q_tile = tile_elementwise_in(q_element_func, q_tile);
|
||||
|
||||
index_t i_loop = 0;
|
||||
|
||||
do
|
||||
{
|
||||
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
|
||||
store_tile(k_lds_windows[number<i_k0 % NumKLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
if constexpr(i_k0 == 0)
|
||||
clear_tile(s_acc);
|
||||
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
k_tile = load_tile(k_dram_window);
|
||||
if constexpr(i_k0 < k0_loops - 2)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
block_sync_lds();
|
||||
// execute current unroll of gemm_0
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(
|
||||
q_tile, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
|
||||
k_lds_windows[number<i_k0 % NumKLdsBuffers>{}]);
|
||||
});
|
||||
|
||||
store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}],
|
||||
tile_elementwise_in(k_element_func, k_tile));
|
||||
|
||||
// prefetch first v_tile
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_0(s_acc,
|
||||
get_slice_tile(q_tile,
|
||||
sequence<0, (k0_loops - 1) * kK0>{},
|
||||
sequence<kM0, k0_loops * kK0>{}),
|
||||
k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
|
||||
|
||||
static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) {
|
||||
v_tiles[i_buf] = load_tile(v_dram_window);
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
|
||||
// STAGE 2: apply scale_s, add bias, apply the mask, then SiLU
|
||||
if constexpr(kHasBias)
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
tile_elementwise_inout(
|
||||
[&](auto& x, const auto& y) {
|
||||
x += type_convert<GemmAccDataType>(bias_element_func(y));
|
||||
},
|
||||
s_acc,
|
||||
bias_tile);
|
||||
}
|
||||
else
|
||||
{
|
||||
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
|
||||
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
|
||||
}
|
||||
|
||||
move_tile_window(bias_dram_window, {0, kN0});
|
||||
if constexpr(HstuMask::IsMasking)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
}
|
||||
else if constexpr(kPadSeqLenK)
|
||||
{
|
||||
const auto k_origin = k_dram_block_window.get_window_origin();
|
||||
set_tile_if(s_acc, -numeric<CompDataType>::infinity(), [&](auto tile_idx) {
|
||||
if(i_loop < num_loops)
|
||||
return false;
|
||||
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
|
||||
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
|
||||
return mask.IsTokenPairInsideMask(row, col);
|
||||
});
|
||||
};
|
||||
|
||||
auto s = cast_tile<CompDataType>(s_acc); // S{j}
|
||||
|
||||
s = tile_elementwise_in(f_silu, s);
|
||||
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
auto randval_lds_ptr =
|
||||
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>();
|
||||
|
||||
dropout.template Run<decltype(gemm_0), CompDataType, uint8_t>(
|
||||
randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window);
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0x7f);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
|
||||
|
||||
store_tile(
|
||||
v_lds_windows[I0],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_windows[I0],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
const auto p = cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, s));
|
||||
|
||||
// STAGE 3, KV gemm
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
v_tiles[I0] = load_tile(v_dram_window);
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp, v_tiles[I0]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp));
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_tiles[I0]));
|
||||
}
|
||||
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
else // NumVLdsBuffers == 3 or 2
|
||||
{
|
||||
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
v_tiles[number<i_k1 % NumPrefetchV>{}] = load_tile(v_dram_window);
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(
|
||||
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
|
||||
v_lds_windows[number<i_k1 % NumVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(std::is_same_v<VLayout,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
auto v_shuffle_tmp = make_static_distributed_tensor<QKVDataType>(
|
||||
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
|
||||
shuffle_tile(v_shuffle_tmp,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]);
|
||||
store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func, v_shuffle_tmp));
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(
|
||||
v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}],
|
||||
tile_elementwise_in(v_element_func,
|
||||
v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]));
|
||||
}
|
||||
|
||||
if constexpr(i_k1 < k1_loops - NumPrefetchV)
|
||||
move_tile_window(v_dram_window, {0, kK1});
|
||||
});
|
||||
}
|
||||
}
|
||||
// move K tile windows
|
||||
move_tile_window(k_dram_block_window, {kN0, 0});
|
||||
|
||||
block_sync_lds();
|
||||
gemm_1(o_acc,
|
||||
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
|
||||
v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]);
|
||||
|
||||
if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer<Problem>())
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
};
|
||||
} while(++i_loop < num_loops);
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
|
||||
return o_acc;
|
||||
}
|
||||
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp,
|
||||
typename BiasDramBlockWindowTmp>
|
||||
CK_TILE_HOST_DEVICE auto
|
||||
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
|
||||
HstuMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr,
|
||||
DropoutType& dropout) const
|
||||
{
|
||||
return operator()(q_dram_block_window_tmp,
|
||||
identity{},
|
||||
k_dram_block_window_tmp,
|
||||
identity{},
|
||||
v_dram_block_window_tmp,
|
||||
identity{},
|
||||
bias_dram_block_window_tmp,
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
identity{},
|
||||
mask,
|
||||
scale_s,
|
||||
smem_ptr,
|
||||
dropout);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,370 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopy = */ false,
|
||||
/* NumPrefetchK = */ -1,
|
||||
/* NumPrefetchV = */ 2>
|
||||
{
|
||||
static constexpr index_t NumPrefetchV = 2;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers()
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto GetNumPrefetchV()
|
||||
{
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t k1_loops = kN0 / kK1;
|
||||
|
||||
return min(NumPrefetchV, k1_loops);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers()
|
||||
{
|
||||
return 2;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
|
||||
return BlockGemm::template MakeABlockTileDistribution<
|
||||
Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
return 8 / sizeof(QKVDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
constexpr index_t kKVector = GetAlignmentK<Problem>();
|
||||
|
||||
static_assert(kKVector % kKPack == 0);
|
||||
|
||||
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumKLdsBuffers>{},
|
||||
number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kNPerBlock>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock * kNPerBlock + kKPerBlock * kKPack / kKVector>{},
|
||||
number<kNPerBlock * kKVector + kKPack>{},
|
||||
number<kNPerBlock * kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
|
||||
k_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<NumKLdsBuffers>{}, number<kNPerBlock>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKVector>{},
|
||||
number<kKVector / kKPack>{},
|
||||
number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t NThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<NPerThread, NThreadPerWarp, NumWarps>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<2>, sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
using QKVDataType = remove_cvref_t<typename Problem::QKVDataType>;
|
||||
|
||||
constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t Banks = 32; // TODO: may need to change based on arch
|
||||
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QKVDataType);
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(PixelsPerRow % kKPack == 0);
|
||||
constexpr index_t NPerRow = PixelsPerRow / kKPack;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
static_assert(kNPerBlock % NPerRow == 0);
|
||||
static_assert(kKPerBlock % kKPack == 0);
|
||||
|
||||
constexpr index_t VSingleSmemElementSpaceSize =
|
||||
(kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack);
|
||||
|
||||
constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<NumVLdsBuffers>{},
|
||||
number<kKPerBlock / kKPack>{},
|
||||
number<kNPerBlock / NPerRow>{},
|
||||
number<NPerRow>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<VSingleSmemElementSpaceSize>{},
|
||||
number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
|
||||
number<PixelsPerRow + kKPack>{},
|
||||
number<kKPack>{},
|
||||
number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc = transform_tensor_descriptor(
|
||||
v_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(
|
||||
number<NumVLdsBuffers>{}, number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t N0 = kNPerBlock / N1; // P
|
||||
|
||||
constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(ElemPerThread % N1 == 0);
|
||||
constexpr index_t K3 = ElemPerThread / N1;
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3;
|
||||
if constexpr(get_warp_size() % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size();
|
||||
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = GetAlignmentV<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
static_assert(N0 != 0);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
Problem::kNumGemm0Warps * get_warp_size(),
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
constexpr auto warp_gemm = []() {
|
||||
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
|
||||
|
||||
if constexpr(std::is_same_v<typename Problem::QKVDataType, half_t> &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaF16F16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QKVDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
if constexpr(WarpGemmM == 32)
|
||||
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
|
||||
else if constexpr(WarpGemmM == 16)
|
||||
return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{};
|
||||
else // WarpGemmM == 4
|
||||
return WarpGemmMfmaBf16Bf16F32M4N64K16{};
|
||||
}
|
||||
else if constexpr(std::is_same_v<typename Problem::QKVDataType, fp8_t> &&
|
||||
std::is_same_v<typename Problem::GemmAccDataType, float>)
|
||||
{
|
||||
static_assert(WarpGemmM == 32);
|
||||
|
||||
// TODO: hard-coded here; otherwise it may produce incorrect results
|
||||
constexpr index_t swizzle_factor = 4;
|
||||
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
|
||||
swizzle_factor>{};
|
||||
} // TODO - bf8_t
|
||||
}();
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBSmemCRegV2CustomPolicy<typename Problem::QKVDataType,
|
||||
typename Problem::QKVDataType,
|
||||
typename Problem::GemmAccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
if constexpr(1 < Problem::kNumGemm0Warps)
|
||||
return BlockGemmARegBSmemCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
else
|
||||
return BlockGemmARegBSmemCRegOneWarpV1<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
// leave some exclusive space so that the second v_lds buffer will never overlap with the first
// k_lds buffer
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes()
|
||||
{
|
||||
constexpr index_t single_k_lds_buffer_size =
|
||||
GetSmemSizeK<Problem>() / GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t single_v_lds_buffer_size =
|
||||
GetSmemSizeV<Problem>() / GetNumVLdsBuffers<Problem>();
|
||||
|
||||
if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size)
|
||||
return 0;
|
||||
else
|
||||
return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64);
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer()
|
||||
{
|
||||
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
|
||||
|
||||
constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1;
|
||||
constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers<Problem>();
|
||||
constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers<Problem>();
|
||||
|
||||
constexpr index_t last_v_lds_buffer_offset =
|
||||
MakeVLdsBlockDescriptor<Problem>().get_element_space_size() / num_v_lds_buffers *
|
||||
((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType);
|
||||
|
||||
constexpr index_t first_k_lds_buffer_size =
|
||||
MakeKLdsBlockDescriptor<Problem>().get_element_space_size() / num_k_lds_buffers *
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
|
||||
return GetExclusiveKLdsBytes<Problem>() + last_v_lds_buffer_offset <
|
||||
first_k_lds_buffer_size;
|
||||
};
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
|
||||
{
|
||||
return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
|
||||
{
|
||||
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::QKVDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
// assume V can reuse the shared memory used by K, except for the leading exclusive K region
// assume Dropout can reuse the shared memory used by V
|
||||
return GetExclusiveKLdsBytes<Problem>() +
|
||||
max(GetSmemSizeK<Problem>() - GetExclusiveKLdsBytes<Problem>(),
|
||||
max(GetSmemSizeV<Problem>(), GetSmemSizeDropout<Problem>(0)));
|
||||
}
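// illustrative arithmetic (hypothetical sizes): with two K LDS buffers totalling 8 KiB and two
// V LDS buffers totalling 6 KiB, GetExclusiveKLdsBytes() reserves
// integer_least_multiple(4 KiB - 3 KiB, 64) = 1 KiB, so GetSmemSize() returns
// 1 KiB + max(8 KiB - 1 KiB, max(6 KiB, dropout bytes)) = 8 KiB when the dropout buffer fits in 7 KiB.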
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,98 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdBlockTile;
|
||||
|
||||
// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0)
|
||||
//
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<32>
|
||||
{
|
||||
using type = ck_tile::sequence<64, 64, 16, 32, 32, 32>;
|
||||
using gemm0_warps = ck_tile::sequence<2, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<2, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<64>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<128>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 128, 32, 128, 32, 128>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdBlockTile<256>
|
||||
{
|
||||
using type = ck_tile::sequence<128, 128, 32, 256, 32, 256>;
|
||||
using gemm0_warps = ck_tile::sequence<4, 1, 1>;
|
||||
using gemm1_warps = ck_tile::sequence<4, 1, 1>;
|
||||
};
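// Reading the field order from the tile-size comment above, e.g. HstuAttentionFwdBlockTile<64>
// sets M0 = 128, N0 = 64, K0 = 32, N1 = 64, K1 = 32 and MaxK = 64, with 4 warps per GEMM and the
// 32x32x16 warp tile defined below.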
|
||||
|
||||
using HstuAttentionFwdWarpTile1 = ck_tile::sequence<32, 32, 16>;
|
||||
|
||||
template <ck_tile::index_t MaxK>
|
||||
struct HstuAttentionFwdShape;
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdShape<32>
|
||||
{
|
||||
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<32>::type,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<32>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdShape<64>
|
||||
{
|
||||
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<64>::type,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<64>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdShape<128>
|
||||
{
|
||||
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<128>::type,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<128>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdShape<256>
|
||||
{
|
||||
using Type = ck_tile::TileFmhaShape<typename HstuAttentionFwdBlockTile<256>::type,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm0_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
typename HstuAttentionFwdBlockTile<256>::gemm1_warps,
|
||||
HstuAttentionFwdWarpTile1,
|
||||
IsVLayoutRowMajor>;
|
||||
};
|
||||
@@ -0,0 +1,34 @@
|
||||
/*
|
||||
* Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
template <typename InOutDataType>
|
||||
struct HstuAttentionFwdTypeConfig;
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTypeConfig<ck_tile::fp16_t>
|
||||
{
|
||||
using BiasDataType = ck_tile::fp16_t;
|
||||
using GemmAccDataType = float;
|
||||
using CompDataType = float; // data type for non-linear calculation
|
||||
using OaccDataType = GemmAccDataType;
|
||||
using ODataType = ck_tile::fp16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HstuAttentionFwdTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using BiasDataType = ck_tile::bf16_t;
|
||||
using GemmAccDataType = float;
|
||||
using CompDataType = float; // data type for non-linear calculation
|
||||
using OaccDataType = GemmAccDataType;
|
||||
using ODataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
static constexpr bool IsVLayoutRowMajor = true;
|
||||
@@ -0,0 +1,30 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
#define HDIM_SWITCH(HDIM_1, HDIM_2, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if(HDIM_1 <= 64 && HDIM_2 <= 64) \
|
||||
{ \
|
||||
constexpr ck_tile::index_t CONST_NAME = 64; \
|
||||
__VA_ARGS__(); \
|
||||
} \
|
||||
else if(HDIM_1 <= 128 && HDIM_2 <= 128) \
|
||||
{ \
|
||||
constexpr ck_tile::index_t CONST_NAME = 128; \
|
||||
__VA_ARGS__(); \
|
||||
} \
|
||||
else if(HDIM_1 <= 256 && HDIM_2 <= 256) \
|
||||
{ \
|
||||
constexpr ck_tile::index_t CONST_NAME = 256; \
|
||||
__VA_ARGS__(); \
|
||||
} \
|
||||
else \
|
||||
{ \
|
||||
throw std::runtime_error("Head-dim sizes not supported!"); \
|
||||
} \
|
||||
}()
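// usage (mirroring the dispatch code later in this commit):
//   HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { /* MaxK is 64, 128 or 256 here */ });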
|
||||
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
|
||||
#include "instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp"
|
||||
|
||||
void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const bool has_dropout = (param.p_drop > 0.0f);
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
if(param.window_size > 0)
|
||||
{
|
||||
run_jagged_forward_causal_local_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
true,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_jagged_forward_causal_local_bias_dropout_dispatch<ck_tile::bf16_t,
|
||||
kUseCausal,
|
||||
false,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
};
|
||||
});
|
||||
});
|
||||
};
|
||||
@@ -0,0 +1,144 @@
|
||||
/*
|
||||
* Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core/numeric/integer.hpp>
|
||||
#include <ck_tile/host/kernel_launch.hpp>
|
||||
#include <ck_tile/host/stream_config.hpp>
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <ck_tile/ops/fmha.hpp>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_fwd_type_config.hpp"
|
||||
#include "hstu_attention_fwd_setting.hpp"
|
||||
#include "hstu_attention_params.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_block_masking.hpp"
|
||||
#include "hstu_attention_pipeline_problem.hpp"
|
||||
#include "hstu_attention_traits.hpp"
|
||||
#include "hstu_attention_fwd_pipeline.hpp"
|
||||
#include "hstu_attention_fwd_kernel.hpp"
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
struct jagged_forward_causal_local_bias_dropout_dispatch
|
||||
{
|
||||
using HstuAttentionShape = typename HstuAttentionFwdShape<MaxK>::Type;
|
||||
using HstuMask = ck_tile::HstuBlockMasking<kUseCausal, kUseLocal>;
|
||||
|
||||
template <typename HstuTraits>
|
||||
using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem<
|
||||
InOutDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::GemmAccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::CompDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::BiasDataType,
|
||||
true, // kIsJagged
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
HstuMask,
|
||||
HstuAttentionShape,
|
||||
HstuTraits>;
|
||||
|
||||
static void Run(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
constexpr ck_tile::index_t occupancy = -1;
|
||||
|
||||
const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0);
|
||||
const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0);
|
||||
|
||||
// no need to check seqlen_q since it is not used as the fastest dim;
// buffer_load_dwordxx/buffer_store_dwordxx can handle oob access
|
||||
constexpr bool kPadSeqLenQ = false;
|
||||
|
||||
constexpr bool kPadSeqLenK = true;
|
||||
|
||||
BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] {
|
||||
using HstuTraits = ck_tile::HstuAttentionFwdTraits<kPadSeqLenQ,
|
||||
kPadSeqLenK,
|
||||
kPadHeadDimQK,
|
||||
kPadHeadDimV,
|
||||
occupancy>;
|
||||
|
||||
using HstuPipelineProblem = HstuPipelineProblemTemp<HstuTraits>;
|
||||
|
||||
using HstuEpilogue = ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::OaccDataType,
|
||||
typename HstuAttentionFwdTypeConfig<InOutDataType>::ODataType,
|
||||
kPadSeqLenQ,
|
||||
kPadHeadDimV>>;
|
||||
|
||||
using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS<HstuPipelineProblem>;
|
||||
using HstuKernel = ck_tile::HstuAttentionFwdKernel<HstuPipeline, HstuEpilogue>;
|
||||
|
||||
RunWithKernel<HstuKernel>(param, stream);
|
||||
});
|
||||
};
|
||||
|
||||
template <typename HstuKernel>
|
||||
static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const auto kargs = [&] {
|
||||
return HstuKernel::MakeKargs(param.q_ptr,
|
||||
param.k_ptr,
|
||||
param.v_ptr,
|
||||
param.bias_ptr,
|
||||
param.o_ptr,
|
||||
param.seq_offsets_ptr,
|
||||
param.hdim_qk,
|
||||
param.hdim_v,
|
||||
param.num_head,
|
||||
param.scale_s,
|
||||
param.seq_stride_q,
|
||||
param.seq_stride_k,
|
||||
param.seq_stride_v,
|
||||
param.seq_stride_bias,
|
||||
param.seq_stride_o,
|
||||
param.nhead_stride_q,
|
||||
param.nhead_stride_k,
|
||||
param.nhead_stride_v,
|
||||
param.nhead_stride_bias,
|
||||
param.nhead_stride_o,
|
||||
param.num_targets_ptr,
|
||||
param.window_size,
|
||||
param.contextual_seqlen,
|
||||
param.min_full_attn_seqlen,
|
||||
param.p_drop,
|
||||
param.philox_seed,
|
||||
param.philox_offset);
|
||||
}();
|
||||
|
||||
dim3 kGridSize =
|
||||
HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v);
|
||||
constexpr dim3 kBlockSize = HstuKernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu;
|
||||
|
||||
(void)ck_tile::launch_kernel(ck_tile::stream_config{stream, false},
|
||||
ck_tile::make_kernel<kBlockSize.x, kBlockPerCu>(
|
||||
HstuKernel{}, kGridSize, kBlockSize, 0, kargs));
|
||||
};
|
||||
};
|
||||
|
||||
template <typename InOutDataType,
|
||||
bool kUseCausal,
|
||||
bool kUseLocal,
|
||||
bool kHasBias,
|
||||
bool kHasDropout,
|
||||
ck_tile::index_t MaxK>
|
||||
void run_jagged_forward_causal_local_bias_dropout_dispatch(HstuAttentionFwdParams& param,
|
||||
hipStream_t stream)
|
||||
{
|
||||
jagged_forward_causal_local_bias_dropout_dispatch<InOutDataType,
|
||||
kUseCausal,
|
||||
kUseLocal,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>::Run(param, stream);
|
||||
};
|
||||
@@ -0,0 +1,43 @@
|
||||
/*
|
||||
* Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
#include <ck_tile/core.hpp>
|
||||
#include <stdexcept>
|
||||
|
||||
#include "hstu_attention_bool_switch.hpp"
|
||||
#include "hstu_attention_hdim_switch.hpp"
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
|
||||
#include "instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp"
|
||||
|
||||
void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream)
|
||||
{
|
||||
const bool has_dropout = (param.p_drop > 0.0f);
|
||||
const bool has_bias = (param.bias_ptr != nullptr);
|
||||
const bool use_causal = param.use_causal;
|
||||
BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] {
|
||||
HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] {
|
||||
if(param.window_size > 0)
|
||||
{
|
||||
run_jagged_forward_causal_local_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
true,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
run_jagged_forward_causal_local_bias_dropout_dispatch<ck_tile::fp16_t,
|
||||
kUseCausal,
|
||||
false,
|
||||
kHasBias,
|
||||
kHasDropout,
|
||||
MaxK>(param, stream);
|
||||
};
|
||||
});
|
||||
});
|
||||
};
|
||||
example/ck_tile/18_hstu_attention/hstu_attention_params.hpp (new file, 57 lines)
@@ -0,0 +1,57 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
struct HstuAttentionFwdParams
|
||||
{
|
||||
bool is_jagged;
|
||||
|
||||
ck_tile::index_t num_batch;
|
||||
ck_tile::index_t seqlen; // batched mode only
|
||||
const void* seq_offsets_ptr; // jagged mode only
|
||||
ck_tile::index_t max_seqlen; // jagged mode only
|
||||
|
||||
const void* q_ptr;
|
||||
const void* k_ptr;
|
||||
const void* v_ptr;
|
||||
const void* bias_ptr;
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t hdim_qk;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t num_head;
|
||||
float scale_s;
|
||||
|
||||
ck_tile::index_t seq_stride_q;
|
||||
ck_tile::index_t seq_stride_k;
|
||||
ck_tile::index_t seq_stride_v;
|
||||
ck_tile::index_t seq_stride_bias;
|
||||
ck_tile::index_t seq_stride_o;
|
||||
|
||||
ck_tile::index_t nhead_stride_q;
|
||||
ck_tile::index_t nhead_stride_k;
|
||||
ck_tile::index_t nhead_stride_v;
|
||||
ck_tile::index_t nhead_stride_bias;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
// batched mode only parameters
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_bias;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
|
||||
const void* num_targets_ptr;
|
||||
|
||||
bool use_causal;
|
||||
ck_tile::index_t window_size;
|
||||
ck_tile::index_t contextual_seqlen;
|
||||
ck_tile::index_t min_full_attn_seqlen;
|
||||
|
||||
float p_drop;
|
||||
uint64_t philox_seed;
|
||||
uint64_t philox_offset;
|
||||
};
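A rough host-side sketch of how these fields might be populated for jagged mode (pointer names
and sizes below are placeholders, not part of this commit; the interface function is the one
declared elsewhere in the example):

HstuAttentionFwdParams param{};
param.is_jagged       = true;
param.num_batch       = 8;
param.max_seqlen      = 4096;             // jagged mode only
param.seq_offsets_ptr = seq_offsets_dev;  // device buffer of per-sequence offsets
param.q_ptr = q_dev; param.k_ptr = k_dev; param.v_ptr = v_dev; param.o_ptr = o_dev;
param.bias_ptr        = nullptr;          // no attention bias
param.hdim_qk = 128; param.hdim_v = 128; param.num_head = 4;
param.scale_s = 1.0f / 128;
param.use_causal = true; param.window_size = 0;
param.contextual_seqlen = 0; param.min_full_attn_seqlen = 0;
param.p_drop = 0.0f;
// seq_stride_* / nhead_stride_* must match the chosen Q/K/V/O memory layout
hstu_attention_jagged_forward_fp16(param, stream);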
|
||||
@@ -0,0 +1,62 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// PipelineProblem encodes information not only from the original user problem,
// but also other information needed by the pipeline, which includes:
//   TileShape -- determines how block-level calculation is tiled and how warps
//                are allocated across dimensions
//   Traits    -- other information required for running the kernel and pipeline
|
||||
|
||||
template <typename InOutDataType_,
|
||||
typename GemmAccDataType_,
|
||||
typename CompDataType_, // data type for SiLU and other non-linear calculation
|
||||
typename BiasDataType_,
|
||||
bool kIsJagged_,
|
||||
bool kHasBias_,
|
||||
bool kHasDropout_,
|
||||
typename HstuMask_, // encoding Causal and Local, contextual masking
|
||||
typename AttentionTileShape_,
|
||||
typename Traits_>
|
||||
struct HstuAttentionFwdPipelineProblem
|
||||
{
|
||||
using InOutDataType = remove_cvref_t<InOutDataType_>;
|
||||
using QKVDataType = InOutDataType;
|
||||
using ODataType = InOutDataType;
|
||||
using GemmAccDataType = remove_cvref_t<GemmAccDataType_>;
|
||||
|
||||
// data type used for the SiLU calculation
|
||||
using CompDataType = remove_cvref_t<CompDataType_>;
|
||||
using BiasDataType = remove_cvref_t<BiasDataType_>;
|
||||
|
||||
// to be compatible with existing ck_tile policy code
|
||||
using QDataType = QKVDataType;
|
||||
using KDataType = QKVDataType;
|
||||
using VDataType = QKVDataType;
|
||||
using SaccDataType = GemmAccDataType;
|
||||
using OaccDataType = GemmAccDataType;
|
||||
using PDataType = QKVDataType;
|
||||
|
||||
static constexpr bool kIsJagged = kIsJagged_;
|
||||
static constexpr bool kHasBias = kHasBias_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
|
||||
using HstuMask = remove_cvref_t<HstuMask_>;
|
||||
|
||||
using HstuAttentionTileShape = remove_cvref_t<AttentionTileShape_>;
|
||||
|
||||
// keep the name compatible with existing ck_tile policy code; to be changed later
|
||||
using BlockFmhaShape = HstuAttentionTileShape;
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
static constexpr index_t kNumGemm0Warps = AttentionTileShape_::NumGemm0Warps;
|
||||
static constexpr index_t kNumGemm1Warps = AttentionTileShape_::NumGemm1Warps;
|
||||
static constexpr index_t kBlockSize = AttentionTileShape_::NumWarps * get_warp_size();
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,24 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
|
||||
// Type configuration
|
||||
template <typename DataType>
|
||||
struct HSTUAttentionTypeConfig;
|
||||
|
||||
template <>
|
||||
struct HSTUAttentionTypeConfig<ck_tile::fp16_t>
|
||||
{
|
||||
using GemmAccDataType = float;
|
||||
using SMComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct HSTUAttentionTypeConfig<ck_tile::bf16_t>
|
||||
{
|
||||
using GemmAccDataType = float;
|
||||
using SMComputeDataType = float;
|
||||
};
|
||||
example/ck_tile/18_hstu_attention/hstu_attention_traits.hpp (new file, 25 lines)
@@ -0,0 +1,25 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadSeqLenQ_,
|
||||
bool kPadSeqLenK_,
|
||||
bool kPadHeadDimQK_,
|
||||
bool kPadHeadDimV_,
|
||||
index_t kBlockPerCu_>
|
||||
struct HstuAttentionFwdTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQK = kPadHeadDimQK_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
example/ck_tile/18_hstu_attention/hstu_block_masking.hpp (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kUseCausal, bool kUseLocal>
|
||||
struct HstuBlockMasking
|
||||
{
|
||||
static constexpr bool IsMasking = (kUseCausal || kUseLocal);
|
||||
|
||||
int max_attn_len;
|
||||
int contextual_seqlen;
|
||||
int min_full_attn_seqlen;
|
||||
int max_uih_len;
|
||||
|
||||
CK_TILE_HOST_DEVICE HstuBlockMasking(int max_attn_len_,
|
||||
int contextual_seqlen_,
|
||||
int min_full_attn_seqlen_,
|
||||
int max_uih_len_)
|
||||
{
|
||||
max_attn_len = max_attn_len_;
|
||||
contextual_seqlen = contextual_seqlen_;
|
||||
min_full_attn_seqlen = min_full_attn_seqlen_;
|
||||
max_uih_len = max_uih_len_;
|
||||
};
|
||||
|
||||
// get the loop range along the X axis; returns indices [start, end) with end - start = length.
// use this when looping over the X axis tile by tile (e.g. the seqlen_k loop)
|
||||
template <index_t YTile, index_t XTile>
|
||||
CK_TILE_HOST_DEVICE constexpr auto
|
||||
GetTileRangeAlongX(index_t i_y, number<YTile>, number<XTile>) const
|
||||
{
|
||||
if constexpr(!IsMasking)
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
if(contextual_seqlen > 0 && (i_y < contextual_seqlen))
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
|
||||
if constexpr(kUseCausal && !kUseLocal)
|
||||
{
|
||||
index_t x_end =
|
||||
min(i_y + YTile, max_uih_len); // for lower-triangular masking, x <= y
|
||||
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
}
|
||||
else if constexpr(!kUseCausal && kUseLocal)
|
||||
{
|
||||
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_start = max(0, i_y - max_attn_len);
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
|
||||
return ck_tile::make_tuple(x_start - x_start % XTile, x_end);
|
||||
};
|
||||
}
|
||||
else // kUseCausal && kUseLocal
|
||||
{
|
||||
if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen)
|
||||
{
|
||||
return ck_tile::make_tuple(0, max_uih_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
index_t x_end = i_y + YTile + max_attn_len;
|
||||
|
||||
return ck_tile::make_tuple(0, x_end);
|
||||
};
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col)
|
||||
{
|
||||
if(row < contextual_seqlen)
|
||||
return true;
|
||||
|
||||
bool result = false;
|
||||
if constexpr(kUseLocal)
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
result = (row >= col) && (row - col <= max_attn_len);
|
||||
else
|
||||
result = std::abs(row - col) <= max_attn_len;
|
||||
|
||||
if(min_full_attn_seqlen > 0)
|
||||
result = result || (row >= max_uih_len - min_full_attn_seqlen);
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(kUseCausal)
|
||||
result = (row >= col);
|
||||
};
|
||||
|
||||
return result;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
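For reference, the combined causal + local predicate admits a pair (row, col) when col does not lie in the future of row, the distance row - col stays within max_attn_len, and rows inside the contextual prefix or inside the trailing min_full_attn_seqlen window attend everywhere. A standalone host-side sketch of that predicate, using hypothetical values and written against plain ints (independent of the header above):

#include <cassert>

// Standalone mirror of the causal + local branch of IsTokenPairInsideMask,
// so the masking rule can be checked on the host with small numbers.
bool inside_mask(int row, int col, int max_attn_len, int contextual_seqlen,
                 int min_full_attn_seqlen, int max_uih_len)
{
    if(row < contextual_seqlen)
        return true;                                                // contextual prefix attends everywhere
    bool result = (row >= col) && (row - col <= max_attn_len);      // causal + sliding window
    if(min_full_attn_seqlen > 0)
        result = result || (row >= max_uih_len - min_full_attn_seqlen); // trailing full-attention rows
    return result;
}

int main()
{
    // max_attn_len = 4, no contextual prefix, last 8 of 64 rows use full attention
    assert(inside_mask(10, 7, 4, 0, 8, 64));   // within the local window
    assert(!inside_mask(10, 3, 4, 0, 8, 64));  // too far in the past
    assert(!inside_mask(10, 12, 4, 0, 8, 64)); // future token, blocked by causality
    assert(inside_mask(60, 3, 4, 0, 8, 64));   // row inside the trailing full-attention region
    return 0;
}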
@@ -0,0 +1,22 @@
/*
 Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"

template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
@@ -0,0 +1,206 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"

extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, true, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, true, true, false, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
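This translation unit only re-declares the generated instantiations as `extern template`, so code that includes the dispatch header links against the single definition compiled in each generated instance file instead of re-instantiating the kernel runner everywhere. A self-contained sketch of that C++ mechanism with hypothetical names (unrelated to the repository's actual templates; in the real layout the declaration and the definition live in different .cpp files):

#include <iostream>

// Hypothetical stand-in for a heavy kernel-runner template.
template <typename T, int TileSize>
void run_stub(T value)
{
    std::cout << "tile=" << TileSize << " value=" << value << '\n';
}

// Explicit instantiation declaration: "the symbol is defined elsewhere",
// which suppresses implicit instantiation in this translation unit ...
extern template void run_stub<float, 64>(float value);
// ... and the explicit instantiation definition that one generated TU owns.
template void run_stub<float, 64>(float value);

int main()
{
    run_stub<float, 64>(1.5f); // resolves to the explicitly instantiated symbol
    return 0;
}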
@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, true, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t, false, false, false, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
@@ -0,0 +1,206 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"

extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, true, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, true, true, false, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
extern template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);
@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, true, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, true, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, true, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, true, 64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, false, 128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, false, 256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/* Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 * Auto-generated by generate_instances.py (BSD-style license, see LICENSE); do not modify. */
#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_batched_forward_dispatch.hpp"
template void run_batched_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t, false, false, false, false, 64>(HstuAttentionFwdParams& param, hipStream_t stream);
@@ -0,0 +1,22 @@
|
||||
|
||||
/*
|
||||
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*
|
||||
* The file is automatically generated, don't modify!
|
||||
* See the generator script
|
||||
* `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
*/
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
|
||||
template void run_jagged_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
128>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -0,0 +1,22 @@
|
||||
|
||||
/*
|
||||
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*
|
||||
* The file is automatically generated, don't modify!
|
||||
* See the generator script
|
||||
* `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
*/
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
|
||||
template void run_jagged_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
256>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -0,0 +1,22 @@
|
||||
|
||||
/*
|
||||
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
*
|
||||
* This source code is licensed under the BSD-style license found in the
|
||||
* LICENSE file in the root directory of this source tree.
|
||||
*
|
||||
* The file is automatically generated, don't modify!
|
||||
* See the generator script
|
||||
* `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
|
||||
*/
|
||||
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include "hstu_attention_jagged_forward_dispatch.hpp"
|
||||
|
||||
template void run_jagged_forward_causal_local_bias_dropout_dispatch<
|
||||
ck_tile::bf16_t,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
64>(HstuAttentionFwdParams& param, hipStream_t stream);
|
||||
@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    false,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    false,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    true,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    true,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    true,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    false,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    false,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);
@@ -0,0 +1,206 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    true,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    true,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    false,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    false,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    true,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    true,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    false,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    false,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    true,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    true,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    true,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    true,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    true,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    true,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    true,
    false,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    false,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    true,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    true,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    true,
    true,
    false,
    false,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

extern template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    false,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);
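The interface file above pairs with the per-combination translation units: it declares every bf16 instantiation as `extern template`, so a dispatcher can reference the prebuilt kernels without re-instantiating them in its own translation unit. How runtime flags are folded into these compile-time parameters is handled by the project's own dispatch headers, which are not shown in this diff; the snippet below is only a rough sketch of that idea, under the assumptions that the flags mean (causal, local, bias, dropout), that the first two are always toggled together (true for every combination declared above), and that only head dimensions 64, 128 and 256 exist.

// Rough sketch only -- not the project's actual dispatcher. Parameter meanings
// (causal, local, bias, dropout) are assumptions inferred from the function name.
#include <stdexcept>

#include "hstu_attention_jagged_forward_dispatch.hpp"

template <typename DataType, int kHeadDim, bool kCausalLocal>
static void dispatch_bias_dropout(HstuAttentionFwdParams& param, hipStream_t stream,
                                  bool bias, bool dropout)
{
    // Expand the two remaining runtime flags into the compile-time flags of the
    // pre-instantiated templates declared above.
    if(bias && dropout)
        run_jagged_forward_causal_local_bias_dropout_dispatch<DataType, kCausalLocal, kCausalLocal, true, true, kHeadDim>(param, stream);
    else if(bias)
        run_jagged_forward_causal_local_bias_dropout_dispatch<DataType, kCausalLocal, kCausalLocal, true, false, kHeadDim>(param, stream);
    else if(dropout)
        run_jagged_forward_causal_local_bias_dropout_dispatch<DataType, kCausalLocal, kCausalLocal, false, true, kHeadDim>(param, stream);
    else
        run_jagged_forward_causal_local_bias_dropout_dispatch<DataType, kCausalLocal, kCausalLocal, false, false, kHeadDim>(param, stream);
}

template <typename DataType, int kHeadDim>
static void dispatch_flags(HstuAttentionFwdParams& param, hipStream_t stream,
                           bool causal_local, bool bias, bool dropout)
{
    if(causal_local)
        dispatch_bias_dropout<DataType, kHeadDim, true>(param, stream, bias, dropout);
    else
        dispatch_bias_dropout<DataType, kHeadDim, false>(param, stream, bias, dropout);
}

template <typename DataType>
void dispatch_jagged_forward(HstuAttentionFwdParams& param, hipStream_t stream,
                             int head_dim, bool causal_local, bool bias, bool dropout)
{
    // Pick the head-dimension specialization; only these three are instantiated above.
    if(head_dim == 64)
        dispatch_flags<DataType, 64>(param, stream, causal_local, bias, dropout);
    else if(head_dim == 128)
        dispatch_flags<DataType, 128>(param, stream, causal_local, bias, dropout);
    else if(head_dim == 256)
        dispatch_flags<DataType, 256>(param, stream, causal_local, bias, dropout);
    else
        throw std::runtime_error("unsupported head dimension");
}

Because every combination is pre-declared `extern` and compiled once in the instance files, a dispatcher along these lines stays cheap to compile and simply links against the already-built kernels.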
@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    true,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    true,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    true,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    false,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    true,
    false,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    true,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    true,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    true,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    false,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/half.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::bf16_t,
    false,
    false,
    false,
    false,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);
@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t,
    true,
    true,
    true,
    true,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t,
    true,
    true,
    true,
    true,
    256>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t,
    true,
    true,
    true,
    true,
    64>(HstuAttentionFwdParams& param, hipStream_t stream);

@@ -0,0 +1,22 @@
/*
Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 *
 * The file is automatically generated, don't modify!
 * See the generator script
 * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py`
 */

#include <ck_tile/core/numeric/bfloat16.hpp>
#include "hstu_attention_jagged_forward_dispatch.hpp"

template void run_jagged_forward_causal_local_bias_dropout_dispatch<
    ck_tile::fp16_t,
    true,
    true,
    true,
    false,
    128>(HstuAttentionFwdParams& param, hipStream_t stream);
Some files were not shown because too many files have changed in this diff.