diff --git a/example/ck_tile/18_hstu_attention/CMakeLists.txt b/example/ck_tile/18_hstu_attention/CMakeLists.txt index b62b32e14e..8d1bfa0d8a 100644 --- a/example/ck_tile/18_hstu_attention/CMakeLists.txt +++ b/example/ck_tile/18_hstu_attention/CMakeLists.txt @@ -2,10 +2,11 @@ set(EXAMPLE_HSTU_ATTENTION "tile_example_hstu_attention") # not using add_example_executable() to add this target, since we don't want this to have # to be included in "make all/install/check" message("adding example ${EXAMPLE_HSTU_ATTENTION}") -##file(GLOB INSTANCE_SRCS instances/*.cpp) +file(GLOB INSTANCE_SRCS instances/*.cpp) +set(INTERFACES_SRCS hstu_attention_jagged_forward_bf16.cpp hstu_attention_jagged_forward_fp16.cpp hstu_attention_batched_forward_bf16.cpp hstu_attention_batched_forward_fp16.cpp) add_executable(${EXAMPLE_HSTU_ATTENTION} EXCLUDE_FROM_ALL example_hstu_attention.cpp) target_include_directories(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) -##target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE hstu_attention_bf16.cpp hstu_attention_fp16.cpp ${INSTANCE_SRCS}) +target_sources(${EXAMPLE_HSTU_ATTENTION} PRIVATE ${INTERFACES_SRCS} ${INSTANCE_SRCS}) set(EXAMPLE_HSTU_ATTENTION_COMPILE_OPTIONS) diff --git a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp index 1c46b7f846..7016746ebd 100644 --- a/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp +++ b/example/ck_tile/18_hstu_attention/example_hstu_attention.cpp @@ -22,10 +22,16 @@ #include #include -#include "hstu_attention_setting.hpp" -#include "bool_switch.hpp" +#include "hstu_attention_fwd_type_config.hpp" +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_params.hpp" #include "reference_hstu_attention.hpp" +extern void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream); +extern void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream); +extern void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream); +extern void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream); + template std::ostream& operator<<(std::ostream& os, const std::vector& v) { @@ -120,25 +126,21 @@ bool run(const ck_tile::ArgParser& arg_parser) bool do_validation = static_cast(arg_parser.get_int("v")); bool is_jagged = static_cast(arg_parser.get_int("jagged")); int num_batch = arg_parser.get_int("b"); - int nhead = arg_parser.get_int("nhead"); + int num_head = arg_parser.get_int("nhead"); int hdim_qk = arg_parser.get_int("hdim_qk"); int hdim_v = arg_parser.get_int("hdim_v"); bool use_causal = static_cast(arg_parser.get_int("causal")); - int max_attn_len = arg_parser.get_int("local_len"); + int window_size = arg_parser.get_int("local_len"); - bool use_local = (max_attn_len > 0); + bool use_local = (window_size > 0); - int contextual_seq_len = arg_parser.get_int("context_len"); - int min_full_seq_len = arg_parser.get_int("minfull_len"); - - int seed = arg_parser.get_int("seed"); + int contextual_seqlen = arg_parser.get_int("context_len"); + int min_full_attn_seqlen = arg_parser.get_int("minfull_len"); + int seed = arg_parser.get_int("seed"); bool measure_perf = static_cast(arg_parser.get_int("perf")); - (void)do_validation; - (void)measure_perf; - std::string str_of_targets = arg_parser.get_str("targets"); std::vector num_targets = get_integers_from_string(str_of_targets); @@ -147,7 +149,8 @@ bool run(const ck_tile::ArgParser& 
arg_parser) std::vector seq_offsets; - int seqlen = 0; // means total seq lengths for jagged + int seqlen = 0; // means total seq lengths for jagged + int max_seqlen = 0; if(is_jagged) { @@ -156,6 +159,7 @@ bool run(const ck_tile::ArgParser& arg_parser) seq_offsets.push_back(0); for(size_t i = 0; i < seq_lengths.size(); i++) { + max_seqlen = max(max_seqlen, seq_lengths[i]); seqlen += seq_lengths[i]; seq_offsets.push_back(seqlen); }; @@ -166,16 +170,16 @@ bool run(const ck_tile::ArgParser& arg_parser) for(size_t i = 0; i < seq_lengths.size(); i++) { - assert(seq_lengths[i] - num_targets[i] >= min_full_seq_len); - assert(seq_lengths[i] - num_targets[i] >= contextual_seq_len); + assert(seq_lengths[i] - num_targets[i] >= min_full_attn_seqlen); + assert(seq_lengths[i] - num_targets[i] >= contextual_seqlen); }; } else { for(size_t i = 0; i < seq_lengths.size(); i++) { - assert(seq_lengths[i] >= min_full_seq_len); - assert(seq_lengths[i] >= contextual_seq_len); + assert(seq_lengths[i] >= min_full_attn_seqlen); + assert(seq_lengths[i] >= contextual_seqlen); }; }; } @@ -188,53 +192,212 @@ bool run(const ck_tile::ArgParser& arg_parser) { assert(1 == num_targets.size()); - assert(seqlen - num_targets[0] >= min_full_seq_len); - assert(seqlen - num_targets[0] >= contextual_seq_len); + assert(seqlen - num_targets[0] >= min_full_attn_seqlen); + assert(seqlen - num_targets[0] >= contextual_seqlen); } else { - assert(seqlen >= min_full_seq_len); - assert(seqlen >= contextual_seq_len); + assert(seqlen >= min_full_attn_seqlen); + assert(seqlen >= contextual_seqlen); }; }; int batches_for_alloc = is_jagged ? 1 : num_batch; ck_tile::HostTensor q_host( - std::array{batches_for_alloc, seqlen, nhead, hdim_qk}); + std::array{batches_for_alloc, seqlen, num_head, hdim_qk}); ck_tile::HostTensor k_host( - std::array{batches_for_alloc, seqlen, nhead, hdim_qk}); + std::array{batches_for_alloc, seqlen, num_head, hdim_qk}); ck_tile::HostTensor v_host( - std::array{batches_for_alloc, seqlen, nhead, hdim_v}); + std::array{batches_for_alloc, seqlen, num_head, hdim_v}); ck_tile::HostTensor o_host_ref( - std::array{batches_for_alloc, seqlen, nhead, hdim_v}); + std::array{batches_for_alloc, seqlen, num_head, hdim_v}); ck_tile::FillNormalDistributionIntegerValue{-2.f, 2.f, seed}(q_host); ck_tile::FillNormalDistributionIntegerValue{-2.f, 2.f, seed}(k_host); ck_tile::FillNormalDistributionIntegerValue{-2.f, 2.f, seed}(v_host); - using GemmAccDataType = typename HSTUAttentionTypeConfig::GemmAccDataType; - using SMComputeDataType = typename HSTUAttentionTypeConfig::SMComputeDataType; + ck_tile::DeviceMem q_dev(q_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem k_dev(k_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem v_dev(v_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem o_dev(o_host_ref.get_element_space_size_in_bytes()); - BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] { - ck_tile::reference_hstu_attention::Run(q_host, - k_host, - v_host, - o_host_ref, - num_batch, - 1.0f, - seq_offsets, - num_targets, - max_attn_len, - contextual_seq_len, - min_full_seq_len); - }); - return 0; + ck_tile::DeviceMem seq_offsets_dev(seq_offsets.size() * sizeof(int)); + ck_tile::DeviceMem num_targets_dev(num_targets.size() * sizeof(int)); + + q_dev.ToDevice(q_host.data()); + k_dev.ToDevice(k_host.data()); + v_dev.ToDevice(v_host.data()); + + if(is_jagged) + seq_offsets_dev.ToDevice(seq_offsets.data()); + if(!num_targets.empty()) + 
num_targets_dev.ToDevice(num_targets.data()); + + HstuAttentionFwdParams params; + + if(is_jagged) + { + params.is_jagged = true; + params.num_batch = num_batch; + params.seq_offsets_ptr = seq_offsets_dev.GetDeviceBuffer(); + params.max_seqlen = max_seqlen; + params.q_ptr = q_dev.GetDeviceBuffer(); + params.k_ptr = k_dev.GetDeviceBuffer(); + params.v_ptr = v_dev.GetDeviceBuffer(); + params.bias_ptr = nullptr; + params.o_ptr = o_dev.GetDeviceBuffer(); + params.hdim_qk = hdim_qk; + params.hdim_v = hdim_v; + params.num_head = num_head; + params.scale_s = 1.0f / std::sqrt(params.hdim_qk); + params.seq_stride_q = q_host.get_strides()[1]; + params.seq_stride_k = k_host.get_strides()[1]; + params.seq_stride_v = v_host.get_strides()[1]; + params.seq_stride_bias = 0; + params.seq_stride_o = o_host_ref.get_strides()[1]; + params.nhead_stride_q = q_host.get_strides()[2]; + params.nhead_stride_k = k_host.get_strides()[2]; + params.nhead_stride_v = v_host.get_strides()[2]; + params.nhead_stride_bias = 0; + params.nhead_stride_o = o_host_ref.get_strides()[2]; + params.num_targets_ptr = num_targets.empty() ? nullptr : num_targets_dev.GetDeviceBuffer(); + params.use_causal = use_causal; + params.window_size = window_size; + params.contextual_seqlen = contextual_seqlen; + params.min_full_attn_seqlen = min_full_attn_seqlen; + params.p_drop = 0.0f; // dropout is not supported at present + params.philox_seed = 0UL; + params.philox_offset = 0UL; + } + else + { + params.is_jagged = false; + params.num_batch = num_batch; + params.seqlen = seqlen; + params.q_ptr = q_dev.GetDeviceBuffer(); + params.k_ptr = k_dev.GetDeviceBuffer(); + params.v_ptr = v_dev.GetDeviceBuffer(); + params.bias_ptr = nullptr; + params.o_ptr = o_dev.GetDeviceBuffer(); + params.hdim_qk = hdim_qk; + params.hdim_v = hdim_v; + params.num_head = num_head; + params.scale_s = 1.0f / std::sqrt(params.hdim_qk); + params.seq_stride_q = q_host.get_strides()[1]; + params.seq_stride_k = k_host.get_strides()[1]; + params.seq_stride_v = v_host.get_strides()[1]; + params.seq_stride_bias = 0; + params.seq_stride_o = o_host_ref.get_strides()[1]; + params.nhead_stride_q = q_host.get_strides()[2]; + params.nhead_stride_k = k_host.get_strides()[2]; + params.nhead_stride_v = v_host.get_strides()[2]; + params.nhead_stride_bias = 0; + params.nhead_stride_o = o_host_ref.get_strides()[2]; + params.batch_stride_q = q_host.get_strides()[0]; + params.batch_stride_k = k_host.get_strides()[0]; + params.batch_stride_v = v_host.get_strides()[0]; + params.batch_stride_bias = 0; + params.batch_stride_o = o_host_ref.get_strides()[0]; + params.num_targets_ptr = num_targets.empty() ? 
nullptr : num_targets_dev.GetDeviceBuffer(); + params.use_causal = use_causal; + params.window_size = window_size; + params.contextual_seqlen = contextual_seqlen; + params.min_full_attn_seqlen = min_full_attn_seqlen; + params.p_drop = 0.0f; // dropout is not supported at present + params.philox_seed = 0UL; + params.philox_offset = 0UL; + }; + + hipStream_t stream; + + HIP_CHECK_ERROR(hipStreamCreate(&stream)); + + if constexpr(std::is_same::value) + { + if(is_jagged) + hstu_attention_jagged_forward_fp16(params, stream); + else + hstu_attention_batched_forward_fp16(params, stream); + } + else if constexpr(std::is_same::value) + { + if(is_jagged) + hstu_attention_jagged_forward_bf16(params, stream); + else + hstu_attention_batched_forward_bf16(params, stream); + } + else + throw std::runtime_error("Other data type is not supported at present!"); + + bool res = true; + + if(do_validation) + { + using GemmAccDataType = typename HstuAttentionFwdTypeConfig::GemmAccDataType; + using CompDataType = typename HstuAttentionFwdTypeConfig::CompDataType; + + BOOL_SWITCH_3(is_jagged, kIsJagged, use_causal, kUseCausal, use_local, kUseLocal, [&] { + ck_tile::reference_hstu_attention::Run(q_host, + k_host, + v_host, + o_host_ref, + num_batch, + 1.0f, + seq_offsets, + num_targets, + window_size, + contextual_seqlen, + min_full_attn_seqlen); + }); + + ck_tile::HostTensor o_host( + std::array{batches_for_alloc, seqlen, num_head, hdim_v}); + + o_dev.FromDevice(o_host.data()); + + auto [rtol, atol] = get_elimit(); + + res = ck_tile::check_err( + o_host, o_host_ref, std::string("hstu_attention output error"), atol, rtol); + }; + + if(measure_perf) + { + ck_tile::gpu_timer timer{}; + + timer.start(stream); + for(int i = 0; i < 20; i++) + { + if constexpr(std::is_same::value) + { + if(is_jagged) + hstu_attention_jagged_forward_fp16(params, stream); + else + hstu_attention_batched_forward_fp16(params, stream); + } + else if constexpr(std::is_same::value) + { + if(is_jagged) + hstu_attention_jagged_forward_bf16(params, stream); + else + hstu_attention_batched_forward_bf16(params, stream); + } + } + timer.stop(stream); + + auto ms = timer.duration() / 20.f; + + std::cout << "Average execution time of the gather_attention operator is " << ms + << " milli-seconds" << std::endl; + } + + return res; } int main(int argc, char* argv[]) diff --git a/example/ck_tile/18_hstu_attention/generate_instances.py b/example/ck_tile/18_hstu_attention/generate_instances.py new file mode 100644 index 0000000000..2da0a4bb69 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/generate_instances.py @@ -0,0 +1,177 @@ +# noqa: C801 +# Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +import os +from pathlib import Path +from typing import List + +HSTU_COPYRIGHT_HEADER = """ +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `{file}` + */ +""".format( + file=os.path.relpath(os.path.realpath(__file__), start=Path(__file__).parents[4]) +) + +HSTU_FORWARD_INSTANCE_TEMPLATE_INC = """ +#include +#include \"hstu_attention_{mode}_forward_dispatch.hpp\" +""" + +HSTU_FORWARD_INSTANCE_TEMPLATE = """ +{extern}template void run_{mode}_forward_causal_local_bias_dropout_dispatch< + {dtype}, + {has_causal}, + {has_local}, + {has_bias}, + {has_dropout}, + {max_k}>(HstuAttentionFwdParams& param, hipStream_t stream); +""" + +HSTU_FORWARD_INSTANCE_FNAME = ( + "hstu_attention_{mode}_forward_{dtype_str}_{has_or_no_causal_str}_{has_or_no_local_str}_" + "{has_or_no_bias_str}_{has_or_no_dropout_str}_{max_k_str}.cpp" +) + +HSTU_INSTANCE_REF_FNAME = "hstu_attention_{mode}_{function}_{dtype}_instances_ref.hpp" + +BOOL_MAP = {True: "true", False: "false"} + +BOOL_MAP_CAUSAL = { + True: "has_causal", + False: "no_causal", +} + +BOOL_MAP_LOCAL = { + True: "has_local", + False: "no_local", +} + +BOOL_MAP_BIAS = { + True: "has_bias", + False: "no_bias", +} + +BOOL_MAP_DROPOUT = { + True: "has_dropout", + False: "no_dropout", +} + +INT_MAP_MAX_K = {hd: f"maxk_{hd}" for hd in [64, 128, 256]} + +TYPE_CTYPE_MAP = { + "fp16": "ck_tile::fp16_t", + "bf16": "ck_tile::bf16_t", +} + +TYPE_FNAME_MAP = { + "fp16": "half", + "bf16": "bfloat16", +} + +MODE_NAME_MAP = { + "batched": "Batched", + "jagged": "Jagged", +} + +def create_forward_instances(instance_dir: Path, headdims: List) -> None: + for mode in ["batched", "jagged"]: + for dtype in ["fp16", "bf16"]: + for has_causal, has_local in zip([True, False],[True, False]): + for has_bias in [True, False]: + for has_dropout in [True, False]: + for max_k in headdims: + fname = HSTU_FORWARD_INSTANCE_FNAME.format( + mode=mode, + dtype_str=dtype, + has_or_no_causal_str=BOOL_MAP_CAUSAL[has_causal], + has_or_no_local_str=BOOL_MAP_LOCAL[has_local], + has_or_no_bias_str=BOOL_MAP_BIAS[has_bias], + has_or_no_dropout_str=BOOL_MAP_DROPOUT[has_dropout], + max_k_str=INT_MAP_MAX_K[max_k], + ) + forward_instance_inc = ( + HSTU_FORWARD_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + ) + forward_instance = HSTU_FORWARD_INSTANCE_TEMPLATE.format( + extern="", + mode=mode, + dtype=TYPE_CTYPE_MAP[dtype], + has_causal=BOOL_MAP[has_causal], + has_local=BOOL_MAP[has_local], + has_bias=BOOL_MAP[has_bias], + has_dropout=BOOL_MAP[has_dropout], + max_k=max_k, + cap_mode=MODE_NAME_MAP[mode], + ) + (instance_dir / fname).write_text( + HSTU_COPYRIGHT_HEADER + + forward_instance_inc + + forward_instance + ) + + +def create_forward_instances_ref(instance_dir: Path, headdims: List) -> None: + for mode in ["batched", "jagged"]: + for dtype in ["fp16", "bf16"]: + ref_fname = HSTU_INSTANCE_REF_FNAME.format( + mode=mode, + function="forward", + dtype=dtype, + ) + ref_fname_path = instance_dir / ref_fname + forward_instance_inc = HSTU_FORWARD_INSTANCE_TEMPLATE_INC.format( + mode=mode, + dtype_file=TYPE_FNAME_MAP[dtype], + ) + with open(ref_fname_path, "a") as file: + file.write(HSTU_COPYRIGHT_HEADER) + file.write(forward_instance_inc) + for max_k in headdims: + for has_bias in [True, False]: + for has_dropout in [True, False]: + for has_causal, has_local in zip([True, False],[True, False]): + forward_instance = ( + HSTU_FORWARD_INSTANCE_TEMPLATE.format( + extern="extern ", + mode=mode, + dtype=TYPE_CTYPE_MAP[dtype], + has_causal=BOOL_MAP[has_causal], + has_local=BOOL_MAP[has_local], + has_bias=BOOL_MAP[has_bias], + has_dropout=BOOL_MAP[has_dropout], + max_k=max_k, +
cap_mode=MODE_NAME_MAP[mode], + ) + ) + file.write(forward_instance) + +if __name__ == "__main__": + headdims_fwd = [64, 128, 256] + + this_dir = os.path.dirname(__file__) + output_dir = Path(this_dir) / "instances" + output_dir.mkdir(parents=True, exist_ok=True) + + # remove existing files in the directory + files = os.listdir(output_dir) + for ff in files: + file_path = os.path.join(output_dir, ff) + os.remove(file_path) + + create_forward_instances(output_dir, headdims_fwd) + create_forward_instances_ref(output_dir, headdims_fwd) diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_bf16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_bf16.cpp new file mode 100644 index 0000000000..0a9bdbac6d --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_bf16.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_attention_batched_forward_dispatch.hpp" + +#include "instances/hstu_attention_batched_forward_bf16_instances_ref.hpp" + +void hstu_attention_batched_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream) +{ + const bool has_dropout = (param.p_drop > 0.0f); + const bool has_bias = (param.bias_ptr != nullptr); + const bool use_causal = param.use_causal; + BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] { + HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { + if(param.window_size > 0) + { + run_batched_forward_causal_local_bias_dropout_dispatch(param, stream); + } + else + { + run_batched_forward_causal_local_bias_dropout_dispatch(param, stream); + }; + }); + }); +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp new file mode 100644 index 0000000000..d7c68e7929 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_dispatch.hpp @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ +#pragma once + +#include +#include +#include +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_fwd_type_config.hpp" +#include "hstu_attention_fwd_setting.hpp" +#include "hstu_attention_params.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_block_masking.hpp" +#include "hstu_attention_pipeline_problem.hpp" +#include "hstu_attention_traits.hpp" +#include "hstu_attention_fwd_pipeline.hpp" +#include "hstu_attention_fwd_kernel.hpp" + +template +struct batched_forward_causal_local_bias_dropout_dispatch +{ + using HstuAttentionShape = typename HstuAttentionFwdShape::Type; + using HstuMask = ck_tile::HstuBlockMasking; + + template + using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< + InOutDataType, + typename HstuAttentionFwdTypeConfig::GemmAccDataType, + typename HstuAttentionFwdTypeConfig::CompDataType, + typename HstuAttentionFwdTypeConfig::BiasDataType, + false, // kIsJagged + kHasBias, + kHasDropout, + HstuMask, + HstuAttentionShape, + HstuTraits>; + + static void Run(HstuAttentionFwdParams& param, hipStream_t stream) + { + constexpr ck_tile::index_t occupancy = -1; + + const bool pad_seqlen_k = !(param.seqlen % HstuAttentionShape::kN0 == 0); + const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0); + const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0); + + // no need to check seqlen_q since it is not used as fastest dim, + // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access + constexpr bool kPadSeqLenQ = false; + + BOOL_SWITCH_3( + pad_seqlen_k, + kPadSeqLenK, + pad_headdim_qk, + kPadHeadDimQK, + pad_headdim_v, + kPadHeadDimV, + [&] { + using HstuTraits = ck_tile::HstuAttentionFwdTraits; + + using HstuPipelineProblem = HstuPipelineProblemTemp; + + using HstuEpilogue = ck_tile::Default2DEpilogue::OaccDataType, + typename HstuAttentionFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS; + using HstuKernel = ck_tile::HstuAttentionFwdKernel; + + RunWithKernel(param, stream); + }); + }; + + template + static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream) + { + const auto kargs = [&] { + return HstuKernel::MakeKargs(param.q_ptr, + param.k_ptr, + param.v_ptr, + param.bias_ptr, + param.o_ptr, + param.seqlen, + param.hdim_qk, + param.hdim_v, + param.num_head, + param.scale_s, + param.seq_stride_q, + param.seq_stride_k, + param.seq_stride_v, + param.seq_stride_bias, + param.seq_stride_o, + param.nhead_stride_q, + param.nhead_stride_k, + param.nhead_stride_v, + param.nhead_stride_bias, + param.nhead_stride_o, + param.batch_stride_q, + param.batch_stride_k, + param.batch_stride_v, + param.batch_stride_bias, + param.batch_stride_o, + param.num_targets_ptr, + param.window_size, + param.contextual_seqlen, + param.min_full_attn_seqlen, + param.p_drop, + param.philox_seed, + param.philox_offset); + }(); + + dim3 kGridSize = + HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v); + constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel(ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; + +template +void run_batched_forward_causal_local_bias_dropout_dispatch(HstuAttentionFwdParams& param, + hipStream_t stream) +{ + 
batched_forward_causal_local_bias_dropout_dispatch::Run(param, stream); +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_fp16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_fp16.cpp new file mode 100644 index 0000000000..8fd791380d --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_batched_forward_fp16.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_attention_batched_forward_dispatch.hpp" + +#include "instances/hstu_attention_batched_forward_fp16_instances_ref.hpp" + +void hstu_attention_batched_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream) +{ + const bool has_dropout = (param.p_drop > 0.0f); + const bool has_bias = (param.bias_ptr != nullptr); + const bool use_causal = param.use_causal; + BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] { + HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { + if(param.window_size > 0) + { + run_batched_forward_causal_local_bias_dropout_dispatch(param, stream); + } + else + { + run_batched_forward_causal_local_bias_dropout_dispatch(param, stream); + }; + }); + }); +}; diff --git a/example/ck_tile/18_hstu_attention/bool_switch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_bool_switch.hpp similarity index 100% rename from example/ck_tile/18_hstu_attention/bool_switch.hpp rename to example/ck_tile/18_hstu_attention/hstu_attention_bool_switch.hpp diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp new file mode 100644 index 0000000000..dcaeca2665 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_kernel.hpp @@ -0,0 +1,763 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +#include +#include +#include +#include + +// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q] +// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1] +// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k] +// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k]) +// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k] + +namespace ck_tile { + +template +struct HstuAttentionFwdKernel +{ + using HstuAttentionPipeline = ck_tile::remove_cvref_t; + using EpiloguePipeline = ck_tile::remove_cvref_t; + static constexpr ck_tile::index_t kBlockSize = HstuAttentionPipeline::kBlockSize; + static constexpr ck_tile::index_t kBlockPerCu = HstuAttentionPipeline::kBlockPerCu; + static_assert(kBlockPerCu > 0); + static constexpr ck_tile::index_t kBlockPerCuInput = + HstuAttentionPipeline::Problem::kBlockPerCu; + + using QKVDataType = ck_tile::remove_cvref_t; + using BiasDataType = ck_tile::remove_cvref_t; + using ODataType = ck_tile::remove_cvref_t; + + using VLayout = ck_tile::remove_cvref_t; + + static constexpr bool kIsJagged = HstuAttentionPipeline::kIsJagged; + static constexpr bool kPadSeqLenQ = HstuAttentionPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = HstuAttentionPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQK = HstuAttentionPipeline::kPadHeadDimQK; + static constexpr bool kPadHeadDimV = HstuAttentionPipeline::kPadHeadDimV; + static constexpr auto kHasBias = HstuAttentionPipeline::kHasBias; + static constexpr bool kHasDropout = HstuAttentionPipeline::kHasDropout; + using HstuMask = ck_tile::remove_cvref_t; + static constexpr bool kHasMask = HstuMask::IsMasking; + + template // to avoid duplicated base class problem, introduce an template + // arg + struct HstuAttentionFwdEmptyKargs + { + }; + + // kargs use aggregate initializer, so no constructor will provided + // use inheritance to minimize karg size + // user need to use MakeKargs() function to create kargs. 
+ struct HstuAttentionFwdCommonKargs + { + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + void* o_ptr; + + ck_tile::index_t seqlen; + ck_tile::index_t hdim_qk; + ck_tile::index_t hdim_v; + + ck_tile::index_t num_head; + float scale_s; + + ck_tile::index_t seq_stride_q; + ck_tile::index_t seq_stride_k; + ck_tile::index_t seq_stride_v; + ck_tile::index_t seq_stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_o; + + const int32_t* num_targets_ptr; + }; + + struct HstuAttentionFwdCommonBiasKargs + { + const void* bias_ptr = nullptr; + ck_tile::index_t seq_stride_bias = 0; + ck_tile::index_t nhead_stride_bias = 0; + }; + + struct HstuAttentionFwdBatchModeBiasKargs : HstuAttentionFwdCommonBiasKargs + { + ck_tile::index_t batch_stride_bias = 0; + }; + + struct HstuAttentionFwdMaskKargs + { + ck_tile::index_t window_size; + ck_tile::index_t contextual_seqlen; + ck_tile::index_t min_full_attn_seqlen; + }; + + struct HstuAttentionFwdDropoutSeedOffset + { + uint64_t drop_seed; + uint64_t drop_offset; + }; + + struct HstuAttentionFwdCommonDropoutKargs : HstuAttentionFwdDropoutSeedOffset + { + void init_dropout(float p_drop, uint64_t seed, uint64_t offset) + { + float p_undrop = 1.0 - p_drop; + p_undrop_in_uint8_t = + uint8_t(std::floor(p_undrop * std::numeric_limits::max())); + rp_undrop = 1.0 / p_undrop; + + this->drop_seed = seed; + this->drop_offset = offset; + } + + float rp_undrop = 1; + uint8_t p_undrop_in_uint8_t = std::numeric_limits::max(); + }; + + struct HstuAttentionFwdBatchModeKargs + : HstuAttentionFwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_o; + }; + + struct HstuAttentionFwdJaggModeKargs + : HstuAttentionFwdCommonKargs, + std::conditional_t>, + std::conditional_t>, + std::conditional_t> + { + const int32_t* seq_offsets_ptr; + }; + + using Kargs = std:: + conditional_t; + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + ck_tile::index_t seqlen, + ck_tile::index_t hdim_qk, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head, + float scale_s, + ck_tile::index_t seq_stride_q, + ck_tile::index_t seq_stride_k, + ck_tile::index_t seq_stride_v, + ck_tile::index_t seq_stride_bias, + ck_tile::index_t seq_stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_o, + const void* num_targets_ptr, + ck_tile::index_t window_size, + ck_tile::index_t contextual_seqlen, + ck_tile::index_t min_full_attn_seqlen, + float p_drop, + const std::pair& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + seqlen, + hdim_qk, + hdim_v, + num_head, + scale_s, + seq_stride_q, + seq_stride_k, + seq_stride_v, + seq_stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o, + reinterpret_cast(num_targets_ptr)}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for dropout + batch_stride_q, + 
batch_stride_k, + batch_stride_v, + batch_stride_o}; + + if constexpr(kHasBias) + { + kargs.bias_ptr = bias_ptr; + kargs.seq_stride_bias = seq_stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + kargs.batch_stride_bias = batch_stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size = window_size; + kargs.contextual_seqlen = contextual_seqlen; + kargs.min_full_attn_seqlen = min_full_attn_seqlen; + } + if constexpr(kHasDropout) + { + auto seed = std::get<0>(drop_seed_offset); + auto offset = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + ck_tile::index_t seqlen, + ck_tile::index_t hdim_qk, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head, + float scale_s, + ck_tile::index_t seq_stride_q, + ck_tile::index_t seq_stride_k, + ck_tile::index_t seq_stride_v, + ck_tile::index_t seq_stride_bias, + ck_tile::index_t seq_stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_o, + ck_tile::index_t batch_stride_q, + ck_tile::index_t batch_stride_k, + ck_tile::index_t batch_stride_v, + ck_tile::index_t batch_stride_bias, + ck_tile::index_t batch_stride_o, + const void* num_targets_ptr, + ck_tile::index_t window_size, + ck_tile::index_t contextual_seqlen, + ck_tile::index_t min_full_attn_seqlen, + float p_drop, + uint64_t philox_seed, + uint64_t philox_offset) + { + return MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + o_ptr, + seqlen, + hdim_qk, + hdim_v, + num_head, + scale_s, + seq_stride_q, + seq_stride_k, + seq_stride_v, + seq_stride_bias, + seq_stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + batch_stride_q, + batch_stride_k, + batch_stride_v, + batch_stride_bias, + batch_stride_o, + num_targets_ptr, + window_size, + contextual_seqlen, + min_full_attn_seqlen, + p_drop, + std::make_pair(philox_seed, philox_offset)); + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargsImpl(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + const void* seq_offsets_ptr, + ck_tile::index_t hdim_qk, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head, + float scale_s, + ck_tile::index_t seq_stride_q, + ck_tile::index_t seq_stride_k, + ck_tile::index_t seq_stride_v, + ck_tile::index_t seq_stride_bias, + ck_tile::index_t seq_stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_o, + const void* num_targets_ptr, + ck_tile::index_t window_size, + ck_tile::index_t contextual_seqlen, + ck_tile::index_t min_full_attn_seqlen, + float p_drop, + const std::pair& drop_seed_offset) + { + Kargs kargs{{q_ptr, + k_ptr, + v_ptr, + o_ptr, + -1, // seqlen will be updated by another pointer + hdim_qk, + hdim_v, + num_head, + scale_s, + seq_stride_q, + seq_stride_k, + seq_stride_v, + seq_stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_o, + reinterpret_cast(num_targets_ptr)}, // args for common karg + {}, // placeholder for bias + {}, // placeholder for mask + {}, // placeholder for dropout + reinterpret_cast(seq_offsets_ptr)}; + + if constexpr(kHasBias) + { + 
kargs.bias_ptr = bias_ptr; + kargs.seq_stride_bias = seq_stride_bias; + kargs.nhead_stride_bias = nhead_stride_bias; + } + if constexpr(kHasMask) + { + kargs.window_size = window_size; + kargs.contextual_seqlen = contextual_seqlen; + kargs.min_full_attn_seqlen = min_full_attn_seqlen; + } + if constexpr(kHasDropout) + { + auto seed = std::get<0>(drop_seed_offset); + auto offset = std::get<1>(drop_seed_offset); + kargs.init_dropout(p_drop, seed, offset); + } + + return kargs; + } + + template + CK_TILE_HOST static constexpr std::enable_if_t + MakeKargs(const void* q_ptr, + const void* k_ptr, + const void* v_ptr, + const void* bias_ptr, + void* o_ptr, + const void* seq_offsets_ptr, + ck_tile::index_t hdim_qk, + ck_tile::index_t hdim_v, + ck_tile::index_t num_head, + float scale_s, + ck_tile::index_t seq_stride_q, + ck_tile::index_t seq_stride_k, + ck_tile::index_t seq_stride_v, + ck_tile::index_t seq_stride_bias, + ck_tile::index_t seq_stride_o, + ck_tile::index_t nhead_stride_q, + ck_tile::index_t nhead_stride_k, + ck_tile::index_t nhead_stride_v, + ck_tile::index_t nhead_stride_bias, + ck_tile::index_t nhead_stride_o, + const void* num_targets_ptr, + ck_tile::index_t window_size, + ck_tile::index_t contextual_seqlen, + ck_tile::index_t min_full_attn_seqlen, + float p_drop, + uint64_t philox_seed, + uint64_t philox_offset) + { + return MakeKargsImpl(q_ptr, + k_ptr, + v_ptr, + bias_ptr, + o_ptr, + seq_offsets_ptr, + hdim_qk, + hdim_v, + num_head, + scale_s, + seq_stride_q, + seq_stride_k, + seq_stride_v, + seq_stride_bias, + seq_stride_o, + nhead_stride_q, + nhead_stride_k, + nhead_stride_v, + nhead_stride_bias, + nhead_stride_o, + num_targets_ptr, + window_size, + contextual_seqlen, + min_full_attn_seqlen, + p_drop, + std::make_pair(philox_seed, philox_offset)); + } + + CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_, + ck_tile::index_t nhead_, + ck_tile::index_t seqlen_, + ck_tile::index_t hdim_v_) + { + // TODO: this may need tuning + return dim3(ck_tile::integer_divide_ceil(seqlen_, HstuAttentionPipeline::kM0) * + ck_tile::integer_divide_ceil(hdim_v_, HstuAttentionPipeline::kN1), + nhead_, + batch_size_); + } + + CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs) + { + // const index_t num_tile_m0 = seqlen_q / kM0; + const index_t num_tile_n1 = + ck_tile::integer_divide_ceil(kargs.hdim_v, HstuAttentionPipeline::kN1); + + const index_t i_block = blockIdx.x; + const index_t i_nhead = blockIdx.y; + const index_t i_batch = blockIdx.z; + + const auto f = [](index_t dividend, index_t divisor) { + index_t quotient = dividend / divisor; + index_t modulus = dividend - quotient * divisor; + return ck_tile::make_tuple(quotient, modulus); + }; + + const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1); + + return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return ck_tile::max(HstuAttentionPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + // allocate LDS + __shared__ char smem_ptr[GetSmemSize()]; + + // divide problem + const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs); + + const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * HstuAttentionPipeline::kM0); + const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * HstuAttentionPipeline::kN1); + + long_index_t batch_offset_q 
= 0; + long_index_t batch_offset_k = 0; + long_index_t batch_offset_v = 0; + long_index_t batch_offset_bias = 0; + long_index_t batch_offset_o = 0; + + if constexpr(kIsJagged) + { + // get starting offset for each batch + const long_index_t query_start = kargs.seq_offsets_ptr[i_batch]; + const long_index_t key_start = query_start; + + batch_offset_q = query_start * kargs.seq_stride_q; + batch_offset_k = key_start * kargs.seq_stride_k; + if constexpr(std::is_same_v) + { + batch_offset_v = key_start * kargs.seq_stride_v; + } + else + { + batch_offset_v = key_start; + } + if constexpr(kHasBias) + { + batch_offset_bias = query_start * kargs.seq_stride_bias; + } + batch_offset_o = query_start * kargs.seq_stride_o; + + kargs.seqlen = kargs.seq_offsets_ptr[1] - kargs.seq_offsets_ptr[0]; + + // # of required blocks is different in each groups, terminate unnecessary blocks + // earlier + if(kargs.seqlen <= i_m0) + { + return; + } + } + else + { + batch_offset_q = static_cast(i_batch) * kargs.batch_stride_q; + batch_offset_k = static_cast(i_batch) * kargs.batch_stride_k; + batch_offset_v = static_cast(i_batch) * kargs.batch_stride_v; + if constexpr(kHasBias) + { + batch_offset_bias = static_cast(i_batch) * kargs.batch_stride_bias; + } + batch_offset_o = static_cast(i_batch) * kargs.batch_stride_o; + } + + int max_uih_len = kargs.seqlen; + + if constexpr(kHasMask) + { + if(kargs.contextual_seqlen > 0) + max_uih_len -= kargs.contextual_seqlen - 1; + }; + + if(kargs.num_targets_ptr != nullptr) + { + if constexpr(kIsJagged) + max_uih_len -= kargs.num_targets_ptr[i_batch]; + else + max_uih_len -= kargs.num_targets_ptr[0]; + }; + + HstuMask mask = [&]() { + if constexpr(kHasMask) + return HstuMask{kargs.window_size, + kargs.contextual_seqlen, + kargs.min_full_attn_seqlen, + max_uih_len}; + else + return HstuMask{0, 0, 0, 0}; + }(); + + // for simplicity, batch stride we just modify the pointer + const QKVDataType* q_ptr = reinterpret_cast(kargs.q_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_q + + batch_offset_q; + const QKVDataType* k_ptr = reinterpret_cast(kargs.k_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_k + + batch_offset_k; + const QKVDataType* v_ptr = reinterpret_cast(kargs.v_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_v + + batch_offset_v; + ODataType* o_ptr = reinterpret_cast(kargs.o_ptr) + + static_cast(i_nhead) * kargs.nhead_stride_o + + batch_offset_o; + + // Q/K/V DRAM and DRAM window + const auto q_dram = [&]() { + const auto q_dram_naive = make_naive_tensor_view( + q_ptr, + make_tuple(kargs.seqlen, kargs.hdim_qk), + make_tuple(kargs.seq_stride_q, 1), + number{}, + number<1>{}); + if constexpr(HstuAttentionPipeline::kQLoadOnce) + { + return pad_tensor_view(q_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(q_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + const auto k_dram = [&]() { + const auto k_dram_naive = make_naive_tensor_view( + k_ptr, + make_tuple(kargs.seqlen, kargs.hdim_qk), + make_tuple(kargs.seq_stride_k, 1), + number{}, + number<1>{}); + + return pad_tensor_view(k_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }(); + const auto v_dram = [&]() { + if constexpr(std::is_same_v) + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.seqlen, kargs.hdim_v), + make_tuple(kargs.seq_stride_v, 1), + number{}, + number<1>{}); + + const auto v_dram_transposed = + transform_tensor_view(v_dram_naive, + 
make_tuple(make_pass_through_transform(kargs.hdim_v), + make_pass_through_transform(kargs.seqlen)), + make_tuple(sequence<1>{}, sequence<0>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return pad_tensor_view(v_dram_transposed, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + const auto v_dram_naive = make_naive_tensor_view( + v_ptr, + make_tuple(kargs.hdim_v, kargs.seqlen), + make_tuple(kargs.seq_stride_v, 1), + number{}, + number<1>{}); + + return pad_tensor_view(v_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + auto q_dram_window = make_tile_window( + q_dram, + [&]() { + if constexpr(HstuAttentionPipeline::kQLoadOnce) + return make_tuple(number{}, + number{}); + else + return make_tuple(number{}, + number{}); + }(), + {i_m0, 0}); + + auto k_dram_window = make_tile_window( + k_dram, + make_tuple(number{}, number{}), + {0, 0}); + + auto v_dram_window = make_tile_window( + v_dram, + make_tuple(number{}, number{}), + {i_n1, 0}); + /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove + /// following copy capture of the 'i_nhead' if in C++20 + const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { + constexpr auto bias_dram_window_lengths = make_tuple( + number{}, number{}); + if constexpr(kHasBias) + { + const BiasDataType* bias_ptr = + reinterpret_cast(kargs.bias_ptr) + + static_cast(i_nhead_) * kargs.nhead_stride_bias + + batch_offset_bias; + + const auto bias_dram = [&]() { + const auto bias_dram_naive = make_naive_tensor_view( + bias_ptr, + make_tuple(kargs.seqlen, kargs.seqlen), + make_tuple(kargs.seq_stride_bias, 1), + number{}, + number<1>{}); + + return pad_tensor_view(bias_dram_naive, + bias_dram_window_lengths, + sequence{}); + }(); + + return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0}); + } + else + { + return make_null_tile_window(bias_dram_window_lengths); + } + }(); + + auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() { + if constexpr(kHasDropout) + { + return BlockDropout{i_batch_, + i_nhead_, + kargs.num_head, + kargs.drop_seed, + kargs.drop_offset, + kargs.rp_undrop, + kargs.p_undrop_in_uint8_t, + false}; + } + else + { + return NullBlockDropout{}; + }; + }(); + + auto o_acc_tile = [&]() { + return HstuAttentionPipeline{}(q_dram_window, + k_dram_window, + v_dram_window, + bias_dram_window, + mask, + kargs.scale_s, + smem_ptr, + dropout); + }(); + + // O DRAM and O DRAM window + auto o_dram = [&]() { + const auto o_dram_naive = make_naive_tensor_view( + o_ptr, + make_tuple(kargs.seqlen, kargs.hdim_v), + make_tuple(kargs.seq_stride_o, 1), + number{}, + number<1>{}); + + return pad_tensor_view(o_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); + }(); + + auto o_dram_window = make_tile_window( + o_dram, + make_tuple(number{}, number{}), + {i_m0, i_n1}); + + EpiloguePipeline{}(o_dram_window, o_acc_tile); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp new file mode 100644 index 0000000000..f43f45b4e1 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline.hpp @@ -0,0 +1,548 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/block/block_dropout.hpp" + +#include "hstu_attention_fwd_pipeline_default_policy.hpp" + +namespace ck_tile { + +template +struct HstuAttentionFwdPipelineQRKSVS +{ + using Problem = remove_cvref_t; + using Policy = remove_cvref_t; + using QKVDataType = remove_cvref_t; + using GemmAccDataType = remove_cvref_t; + using CompDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + using PDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using HstuMask = remove_cvref_t; + + using HstuAttentionTileShape = remove_cvref_t; + using VLayout = remove_cvref_t; + static constexpr bool kQLoadOnce = true; + static_assert(kQLoadOnce == Policy::QLoadOnce); + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kM0 = HstuAttentionTileShape::kM0; + static constexpr index_t kN0 = HstuAttentionTileShape::kN0; + static constexpr index_t kK0 = HstuAttentionTileShape::kK0; + static constexpr index_t kN1 = HstuAttentionTileShape::kN1; + static constexpr index_t kK1 = HstuAttentionTileShape::kK1; + static constexpr index_t kQKHeaddim = HstuAttentionTileShape::kQKHeaddim; + static constexpr index_t kSubQKHeaddim = HstuAttentionTileShape::kSubQKHeaddim; + + static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!"); + + static constexpr bool kIsJagged = Problem::kIsJagged; + static constexpr auto kHasBias = Problem::kHasBias; + static constexpr bool kHasDropout = Problem::kHasDropout; + + static constexpr bool kPadSeqLenQ = Problem::Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Problem::Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQK = Problem::Traits::kPadHeadDimQK; + static constexpr bool kPadHeadDimV = + (kQKHeaddim < kSubQKHeaddim) ? 1 : Problem::Traits::kPadHeadDimV; + + // last dimension vector length used to create tensor view(and decide buffer_load vector length) + // ... together with tensor distribution. tensor dist should able to overwrite this + static constexpr index_t kAlignmentQ = + kPadHeadDimQK ? 1 : Policy::template GetAlignmentQ(); + static constexpr index_t kAlignmentK = + kPadHeadDimQK ? 1 : Policy::template GetAlignmentK(); + static constexpr index_t kAlignmentV = []() { + if constexpr(std::is_same_v) + return Problem::Traits::kPadHeadDimV ? 1 : Policy::template GetAlignmentV(); + else + return kPadSeqLenK ? 1 : Policy::template GetAlignmentV(); + }(); + + static constexpr index_t kAlignmentO = + kPadHeadDimV ? 1 : Policy::template GetAlignmentO(); + static constexpr index_t kAlignmentBias = + kPadSeqLenK ? 
1 : Policy::template GetAlignmentBias(); + + static constexpr index_t kBlockPerCu = []() { + if constexpr(Problem::Traits::kBlockPerCu != -1) + return Problem::Traits::kBlockPerCu; + else + { + if constexpr(kQKHeaddim == 32) + { + return 2; + } + else if constexpr(kQKHeaddim == 64) + { + return 2; + } + else if constexpr(kQKHeaddim == 96 || kQKHeaddim == 128) + { + if constexpr(kHasBias) + return 1; + else + return 2; + } + else if constexpr(kQKHeaddim == 256) + { + return 1; + } + else + { + return 1; + }; + } + }(); + + static constexpr const char* name = "qr_hstu"; + + using DropoutType = std::conditional_t; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*kSubQKHeaddim tile + const QElementFunction& q_element_func, + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*kSubQKHeaddim tile + const KElementFunction& k_element_func, + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const VElementFunction& v_element_func, + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + const BiasElementFunction& bias_element_func, + const SAccElementFunction& s_acc_element_func, + const PComputeElementFunction& p_compute_element_func, + const OAccElementFunction& o_acc_element_func, + HstuMask mask, + float scale_s, + void* smem_ptr, + DropoutType& dropout) const + { + ignore = q_element_func; + ignore = k_element_func; + + static_assert( + std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && + kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + constexpr auto I0 = number<0>{}; + + constexpr index_t k0_loops = kQKHeaddim / kK0; + constexpr index_t k1_loops = kN0 / kK1; + static_assert(2 <= k0_loops); + static_assert(2 <= k1_loops); + + constexpr auto NumKLdsBuffers = Policy::template GetNumKLdsBuffers(); + constexpr auto NumVLdsBuffers = Policy::template GetNumVLdsBuffers(); + constexpr auto NumPrefetchV = Policy::template GetNumPrefetchV(); + + static_assert(NumKLdsBuffers >= 2); + + auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), + q_dram_block_window_tmp.get_window_lengths(), + q_dram_block_window_tmp.get_window_origin(), + Policy::template MakeQRegTileDistribution()); + + const auto q_origin = q_dram_window.get_window_origin(); + const auto [seqlen_k_start, seqlen_k_end] = + mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number{}, number{}); + + auto k_dram_block_window = + make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), + k_dram_block_window_tmp.get_window_lengths(), + {seqlen_k_start, 0}); + + auto k_dram_window = + make_tile_window(k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + Policy::template MakeKDramTileDistribution()); + + auto k_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {0, kK0}); + + auto q_tile = 
load_tile(q_dram_window); + + __builtin_amdgcn_sched_barrier(0); + + // K tile in LDS + QKVDataType* k_lds_ptr = static_cast(smem_ptr); + auto k_lds = make_tensor_view( + k_lds_ptr, Policy::template MakeKLdsBlockDescriptor()); + auto k_lds_window = make_tile_window( + k_lds, Policy::template MakeKLdsBlockDescriptor().get_lengths(), {0, 0}); + + using k_lds_window_type = + decltype(get_slice_tile(k_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array k_lds_windows; + + static_for<0, NumKLdsBuffers, 1>{}([&](auto i_buf) { + k_lds_windows[i_buf] = get_slice_tile( + k_lds_window, sequence{}, sequence<(i_buf + 1) * kN0, kK0>{}); + }); + + auto v_dram_window = + make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp.get_window_lengths(), + {0, seqlen_k_start}, // TODO: hdim split? + Policy::template MakeVDramTileDistribution()); + // V tile in LDS + auto v_lds = make_tensor_view( + reinterpret_cast(static_cast(smem_ptr) + + Policy::template GetExclusiveKLdsBytes()), + Policy::template MakeVLdsBlockDescriptor()); + auto v_lds_window = make_tile_window( + v_lds, Policy::template MakeVLdsBlockDescriptor().get_lengths(), {0, 0}); + + using v_tile_type = decltype(load_tile(v_dram_window)); + + statically_indexed_array v_tiles; + + using v_lds_window_type = + decltype(get_slice_tile(v_lds_window, sequence<0, 0>{}, sequence{})); + + statically_indexed_array v_lds_windows; + + static_for<0, NumVLdsBuffers, 1>{}([&](auto i_buf) { + v_lds_windows[i_buf] = get_slice_tile( + v_lds_window, sequence{}, sequence<(i_buf + 1) * kN1, kK1>{}); + }); + + // Block GEMM + constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + + using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile()); + auto s_acc = SaccBlockTileType{}; + + // reduction function for softmax + const auto f_silu = [](CompDataType x) { + auto one = ck_tile::type_convert(1.0f); + + auto sigmod_val = one / (one + exp(-x)); + + return sigmod_val * x; + }; + + using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile()); + + // init Oacc, M, L + auto o_acc = OaccBlockTileType{}; + + clear_tile(o_acc); + + const auto num_loops = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + + // check early exit if no work to do + if constexpr(HstuMask::IsMasking || kPadSeqLenK) + { + if(num_loops <= 0) + { + return o_acc; + } + } + + const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); + auto bias_dram_window = + make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(), + bias_dram_block_window_tmp.get_window_lengths(), + {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N + Policy::template MakeBiasDramTileDistribution()); + + auto null_randval_window = [&]() { + if constexpr(kHasDropout) + { + const auto null_randval_dram = [&]() { + const auto null_dram_naive = make_naive_tensor_view( + static_cast(nullptr), + make_tuple(1, 1), + make_tuple(1, 1), + number<1>{}, + number<1>{}); + + return pad_tensor_view(null_dram_naive, + make_tuple(number<1>{}, number<1>{}), + sequence{}); + }(); + + return make_tile_window( + null_randval_dram, make_tuple(number<1>{}, number<1>{}), {0, 0}); + } + else + return make_null_tile_window(make_tuple(number<1>{}, number<1>{})); + }(); + + q_tile = tile_elementwise_in(q_element_func, q_tile); + + index_t i_loop = 0; + + do + { + static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) { + store_tile(k_lds_windows[number{}], + tile_elementwise_in(k_element_func, k_tile)); + if 
constexpr(i_k0 == 0) + clear_tile(s_acc); + + if constexpr(i_k0 < k0_loops - 1) + k_tile = load_tile(k_dram_window); + if constexpr(i_k0 < k0_loops - 2) + move_tile_window(k_dram_window, {0, kK0}); + + block_sync_lds(); + // execute current unroll of gemm_0 + gemm_0(s_acc, + get_slice_tile( + q_tile, sequence<0, i_k0 * kK0>{}, sequence{}), + k_lds_windows[number{}]); + }); + + store_tile(k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}], + tile_elementwise_in(k_element_func, k_tile)); + + // prefetch first v_tile + v_tiles[I0] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + + block_sync_lds(); + gemm_0(s_acc, + get_slice_tile(q_tile, + sequence<0, (k0_loops - 1) * kK0>{}, + sequence{}), + k_lds_windows[number<(k0_loops - 1) % NumKLdsBuffers>{}]); + + __builtin_amdgcn_sched_barrier(0); + + const auto bias_tile = load_tile(bias_dram_window); // load bias tile + + static_for<1, NumPrefetchV, 1>{}([&](auto i_buf) { + v_tiles[i_buf] = load_tile(v_dram_window); + move_tile_window(v_dram_window, {0, kK1}); + }); + + // STAGE 2, scale_s, add bias, mask, siLU + if constexpr(kHasBias) + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + tile_elementwise_inout( + [&](auto& x, const auto& y) { + x += type_convert(bias_element_func(y)); + }, + s_acc, + bias_tile); + } + else + { + s_acc = tile_elementwise_in(s_acc_element_func, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); + } + + move_tile_window(bias_dram_window, {0, kN0}); + if constexpr(HstuMask::IsMasking) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsTokenPairInsideMask(row, col); + }); + } + else if constexpr(kPadSeqLenK) + { + const auto k_origin = k_dram_block_window.get_window_origin(); + set_tile_if(s_acc, -numeric::infinity(), [&](auto tile_idx) { + if(i_loop < num_loops) + return false; + const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{}); + const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); + return mask.IsTokenPairInsideMask(row, col); + }); + }; + + auto s = cast_tile(s_acc); // S{j} + + s = tile_elementwise_in(f_silu, s); + + if constexpr(kHasDropout) + { + auto randval_lds_ptr = + reinterpret_cast(smem_ptr) + Policy::template GetSmemSizeK(); + + dropout.template Run( + randval_lds_ptr, seqlen_k_start + i_loop * kN0, s, null_randval_window); + } + + __builtin_amdgcn_sched_barrier(0x7f); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_tiles[I0]); + + store_tile( + v_lds_windows[I0], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + } + else + { + store_tile(v_lds_windows[I0], + tile_elementwise_in(v_element_func, v_tiles[I0])); // store the prefetch + } + + __builtin_amdgcn_sched_barrier(0); + + const auto p = cast_tile(tile_elementwise_in(p_compute_element_func, s)); + + // STAGE 3, KV gemm + if constexpr(k1_loops > 1) + { + if constexpr(NumPrefetchV == 1) // NumVLdsBuffers == 2 + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + v_tiles[I0] = load_tile(v_dram_window); + + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, 
i_k1 * kK1>{}, sequence{}), + v_lds_windows[number{}]); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, v_tiles[I0]); + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); + } + else + { + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_tiles[I0])); + } + + move_tile_window(v_dram_window, {0, kK1}); + }); + } + else // NumVLdsBuffers == 3 or 2 + { + static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { + if constexpr(i_k1 < k1_loops - NumPrefetchV) + v_tiles[number{}] = load_tile(v_dram_window); + + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile( + p, sequence<0, i_k1 * kK1>{}, sequence{}), + v_lds_windows[number{}]); + + if constexpr(std::is_same_v) + { + auto v_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledVRegBlockDescriptor()); + shuffle_tile(v_shuffle_tmp, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}]); + store_tile(v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, v_shuffle_tmp)); + } + else + { + store_tile( + v_lds_windows[number<(i_k1 + 1) % NumVLdsBuffers>{}], + tile_elementwise_in(v_element_func, + v_tiles[number<(i_k1 + 1) % NumPrefetchV>{}])); + } + + if constexpr(i_k1 < k1_loops - NumPrefetchV) + move_tile_window(v_dram_window, {0, kK1}); + }); + } + } + // move K tile windows + move_tile_window(k_dram_block_window, {kN0, 0}); + + block_sync_lds(); + gemm_1(o_acc, + get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), + v_lds_windows[number<(k1_loops - 1) % NumVLdsBuffers>{}]); + + if constexpr(Policy::template IsFirstKLdsBufferOverlapLastVLdsBuffer()) + { + __builtin_amdgcn_sched_barrier(0); + __builtin_amdgcn_s_barrier(); + }; + } while(++i_loop < num_loops); + + o_acc = tile_elementwise_in(o_acc_element_func, o_acc); + + return o_acc; + } + + template + CK_TILE_HOST_DEVICE auto + operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile + const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile + const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile + const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile + HstuMask mask, + float scale_s, + void* smem_ptr, + DropoutType& dropout) const + { + return operator()(q_dram_block_window_tmp, + identity{}, + k_dram_block_window_tmp, + identity{}, + v_dram_block_window_tmp, + identity{}, + bias_dram_block_window_tmp, + identity{}, + identity{}, + identity{}, + identity{}, + mask, + scale_s, + smem_ptr, + dropout); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000..d2ededb305 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_pipeline_default_policy.hpp @@ -0,0 +1,370 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
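// Note on the pipeline hunk above: unlike softmax attention, this HSTU forward
// pipeline applies SiLU element-wise to the scaled (and optionally biased and
// masked) scores before the second gemm with V. A minimal scalar sketch of the
// activation assumed by `f_silu` (illustrative only; the actual functor is a
// ck_tile element-wise op supplied by the pipeline):
//
//   #include <cmath>
//   inline float silu(float x) { return x / (1.0f + std::exp(-x)); } // SiLU(x) = x * sigmoid(x)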
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" + +namespace ck_tile { + +struct HstuAttentionFwdPipelineQRKSVSDefaultPolicy + : BlockFmhaPipelineQXKSVSCustomPolicy +{ + static constexpr index_t NumPrefetchV = 2; + + template + CK_TILE_DEVICE static constexpr auto GetNumKLdsBuffers() + { + return 2; + } + + template + CK_TILE_DEVICE static constexpr auto GetNumPrefetchV() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t kN0 = BlockFmhaShape::kN0; + constexpr index_t kK1 = BlockFmhaShape::kK1; + + constexpr index_t k1_loops = kN0 / kK1; + + return min(NumPrefetchV, k1_loops); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetNumVLdsBuffers() + { + return 2; + }; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() + { + using BlockGemm = remove_cvref_t())>; + + return BlockGemm::template MakeABlockTileDistribution< + Problem::BlockFmhaShape::kM0, + Problem::BlockFmhaShape::kQKHeaddim>(); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() + { + using QKVDataType = remove_cvref_t; + return 8 / sizeof(QKVDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() + { + constexpr index_t NumKLdsBuffers = GetNumKLdsBuffers(); + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPack = GetSmemKPackK(); + constexpr index_t kKVector = GetAlignmentK(); + + static_assert(kKVector % kKPack == 0); + + constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto k_lds_block_desc = transform_tensor_descriptor( + k_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform(make_tuple(number{}, + number{}, + number{}))), + make_tuple(sequence<0, 3>{}, sequence<1, 2, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return k_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() + { + using QKVDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t MaxVectorSize = 16 / sizeof(QKVDataType); + + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + static_assert(0 < ElemPerThread); + constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); + + constexpr index_t KPerThread = kMaxVecLoad; + constexpr index_t KThreads = kKPerBlock / KPerThread; + constexpr index_t NThreadPerWarp = get_warp_size() / KThreads; + constexpr index_t NumWarps = kBlockSize / get_warp_size(); + constexpr index_t NPerThread = kNPerBlock / (NThreadPerWarp * NumWarps); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<1, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() + { + using QKVDataType = remove_cvref_t; + + constexpr index_t NumVLdsBuffers = GetNumVLdsBuffers(); + + constexpr index_t Banks = 32; // TODO: need change based on arch + constexpr index_t 
PixelsPerRow = Banks * 4 / sizeof(QKVDataType); + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(PixelsPerRow % kKPack == 0); + constexpr index_t NPerRow = PixelsPerRow / kKPack; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + static_assert(kNPerBlock % NPerRow == 0); + static_assert(kKPerBlock % kKPack == 0); + + constexpr index_t VSingleSmemElementSpaceSize = + (kKPerBlock / kKPack) * (kNPerBlock / NPerRow) * (PixelsPerRow + kKPack); + + constexpr auto v_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto v_lds_block_desc = transform_tensor_descriptor( + v_lds_block_desc_0, + make_tuple( + make_merge_transform(make_tuple( + number{}, number{}, number{})), + make_merge_transform(make_tuple(number{}, number{}))), + make_tuple(sequence<0, 2, 3>{}, sequence<1, 4>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return v_lds_block_desc; + } + + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution() + { + using VLayout = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + + if constexpr(std::is_same_v) + { + constexpr index_t N1 = GetAlignmentV(); + constexpr index_t N0 = kNPerBlock / N1; // P + + constexpr index_t ElemPerThread = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(ElemPerThread % N1 == 0); + constexpr index_t K3 = ElemPerThread / N1; + constexpr index_t kKPack = GetSmemKPackV(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = GetAlignmentV(); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + constexpr index_t N1 = kBlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + static_assert(N0 != 0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm0BlockWarps, + typename 
Problem::BlockFmhaShape::Gemm0WarpTile>>; + + constexpr auto warp_gemm = []() { + constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + + if constexpr(std::is_same_v && + std::is_same_v) + { + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaF16F16F32M4N64K16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v) + { + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaBf16Bf16F32M4N64K16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 32); + + // TODO: hard-coded here; otherwise it may produce an incorrect result + constexpr index_t swizzle_factor = 4; + return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution< + swizzle_factor>{}; + } // TODO - bf8_t + }(); + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + + if constexpr(1 < Problem::kNumGemm0Warps) + return BlockGemmARegBSmemCRegV2{}; + else + return BlockGemmARegBSmemCRegOneWarpV1{}; + } + + // leave some exclusive space so that the second v_lds buffer will never overlap with the first + // k_lds buffer + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetExclusiveKLdsBytes() + { + constexpr index_t single_k_lds_buffer_size = + GetSmemSizeK() / GetNumKLdsBuffers(); + constexpr index_t single_v_lds_buffer_size = + GetSmemSizeV() / GetNumVLdsBuffers(); + + if constexpr(single_k_lds_buffer_size <= single_v_lds_buffer_size) + return 0; + else + return integer_least_multiple(single_k_lds_buffer_size - single_v_lds_buffer_size, 64); + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t IsFirstKLdsBufferOverlapLastVLdsBuffer() + { + using BlockFmhaShape = remove_cvref_t; + + constexpr index_t k1_loops = BlockFmhaShape::kN0 / BlockFmhaShape::kK1; + constexpr index_t num_k_lds_buffers = GetNumKLdsBuffers(); + constexpr index_t num_v_lds_buffers = GetNumVLdsBuffers(); + + constexpr index_t last_v_lds_buffer_offset = + MakeVLdsBlockDescriptor().get_element_space_size() / num_v_lds_buffers * + ((k1_loops - 1) % num_v_lds_buffers) * sizeof(typename Problem::VDataType); + + constexpr index_t first_k_lds_buffer_size = + MakeKLdsBlockDescriptor().get_element_space_size() / num_k_lds_buffers * + sizeof(typename Problem::QKVDataType); + + return GetExclusiveKLdsBytes() + last_v_lds_buffer_offset < + first_k_lds_buffer_size; + }; + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() + { + return MakeKLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::QKVDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() + { + return MakeVLdsBlockDescriptor().get_element_space_size() * + sizeof(typename Problem::QKVDataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() + { + // assume V can reuse the shared memory used by K, except for the leading exclusive K region + // assume Dropout can reuse the shared memory used by V + return GetExclusiveKLdsBytes() + + max(GetSmemSizeK() - GetExclusiveKLdsBytes(), + max(GetSmemSizeV(),
GetSmemSizeDropout(0))); + } +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp new file mode 100644 index 0000000000..c842df341b --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_setting.hpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include + +#include "hstu_attention_fwd_type_config.hpp" + +template +struct HstuAttentionFwdBlockTile; + +// Tile-sizes: M N0 K0 N1 K1 MaxK (MaxK % K0 == 0, MaxK % N1 == 0, N0 % K1 == 0) +// +template <> +struct HstuAttentionFwdBlockTile<32> +{ + using type = ck_tile::sequence<64, 64, 16, 32, 32, 32>; + using gemm0_warps = ck_tile::sequence<2, 1, 1>; + using gemm1_warps = ck_tile::sequence<2, 1, 1>; +}; + +template <> +struct HstuAttentionFwdBlockTile<64> +{ + using type = ck_tile::sequence<128, 64, 32, 64, 32, 64>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionFwdBlockTile<128> +{ + using type = ck_tile::sequence<128, 128, 32, 128, 32, 128>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +template <> +struct HstuAttentionFwdBlockTile<256> +{ + using type = ck_tile::sequence<128, 128, 32, 256, 32, 256>; + using gemm0_warps = ck_tile::sequence<4, 1, 1>; + using gemm1_warps = ck_tile::sequence<4, 1, 1>; +}; + +using HstuAttentionFwdWarpTile1 = ck_tile::sequence<32, 32, 16>; + +template +struct HstuAttentionFwdShape; + +template <> +struct HstuAttentionFwdShape<32> +{ + using Type = ck_tile::TileFmhaShape::type, + typename HstuAttentionFwdBlockTile<32>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionFwdBlockTile<32>::gemm1_warps, + HstuAttentionFwdWarpTile1, + IsVLayoutRowMajor>; +}; + +template <> +struct HstuAttentionFwdShape<64> +{ + using Type = ck_tile::TileFmhaShape::type, + typename HstuAttentionFwdBlockTile<64>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionFwdBlockTile<64>::gemm1_warps, + HstuAttentionFwdWarpTile1, + IsVLayoutRowMajor>; +}; + +template <> +struct HstuAttentionFwdShape<128> +{ + using Type = ck_tile::TileFmhaShape::type, + typename HstuAttentionFwdBlockTile<128>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionFwdBlockTile<128>::gemm1_warps, + HstuAttentionFwdWarpTile1, + IsVLayoutRowMajor>; +}; + +template <> +struct HstuAttentionFwdShape<256> +{ + using Type = ck_tile::TileFmhaShape::type, + typename HstuAttentionFwdBlockTile<256>::gemm0_warps, + HstuAttentionFwdWarpTile1, + typename HstuAttentionFwdBlockTile<256>::gemm1_warps, + HstuAttentionFwdWarpTile1, + IsVLayoutRowMajor>; +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_fwd_type_config.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_type_config.hpp new file mode 100644 index 0000000000..4afc9e14d2 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_fwd_type_config.hpp @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
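// Worked reading of one row of the block-tile settings above, following the
// order given in the "Tile-sizes: M N0 K0 N1 K1 MaxK" comment (the per-field
// roles are the usual FMHA tile parameters and are inferred, not stated in the
// file):
//
//   HstuAttentionFwdBlockTile<64>::type == ck_tile::sequence<128, 64, 32, 64, 32, 64>
//     M = 128, N0 = 64, K0 = 32, N1 = 64, K1 = 32, MaxK = 64
//   constraint check: MaxK % K0 == 64 % 32 == 0, MaxK % N1 == 64 % 64 == 0,
//                     N0 % K1 == 64 % 32 == 0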
+ */ +#pragma once + +#include + +template +struct HstuAttentionFwdTypeConfig; + +template <> +struct HstuAttentionFwdTypeConfig +{ + using BiasDataType = ck_tile::fp16_t; + using GemmAccDataType = float; + using CompDataType = float; // data type for non-linear calculation + using OaccDataType = GemmAccDataType; + using ODataType = ck_tile::fp16_t; +}; + +template <> +struct HstuAttentionFwdTypeConfig +{ + using BiasDataType = ck_tile::bf16_t; + using GemmAccDataType = float; + using CompDataType = float; // data type for non-linear calculation + using OaccDataType = GemmAccDataType; + using ODataType = ck_tile::bf16_t; +}; + +static constexpr bool IsVLayoutRowMajor = true; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_hdim_switch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_hdim_switch.hpp new file mode 100644 index 0000000000..ab8cbe31ed --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_hdim_switch.hpp @@ -0,0 +1,30 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include + +#define HDIM_SWITCH(HDIM_1, HDIM_2, CONST_NAME, ...) \ + [&] { \ + if(HDIM_1 <= 64 && HDIM_2 <= 64) \ + { \ + constexpr ck_tile::index_t CONST_NAME = 64; \ + __VA_ARGS__(); \ + } \ + else if(HDIM_1 <= 128 && HDIM_2 <= 128) \ + { \ + constexpr ck_tile::index_t CONST_NAME = 128; \ + __VA_ARGS__(); \ + } \ + else if(HDIM_1 <= 256 && HDIM_2 <= 256) \ + { \ + constexpr ck_tile::index_t CONST_NAME = 256; \ + __VA_ARGS__(); \ + } \ + else \ + { \ + throw std::runtime_error("Head-dim sizes not supported!"); \ + } \ + }() diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_bf16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_bf16.cpp new file mode 100644 index 0000000000..12d88238cf --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_bf16.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_attention_jagged_forward_dispatch.hpp" + +#include "instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp" + +void hstu_attention_jagged_forward_bf16(HstuAttentionFwdParams& param, hipStream_t stream) +{ + const bool has_dropout = (param.p_drop > 0.0f); + const bool has_bias = (param.bias_ptr != nullptr); + const bool use_causal = param.use_causal; + BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] { + HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { + if(param.window_size > 0) + { + run_jagged_forward_causal_local_bias_dropout_dispatch(param, stream); + } + else + { + run_jagged_forward_causal_local_bias_dropout_dispatch(param, stream); + }; + }); + }); +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp new file mode 100644 index 0000000000..5fe497d666 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_dispatch.hpp @@ -0,0 +1,144 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. 
+ * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#pragma once + +#include +#include +#include +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_fwd_type_config.hpp" +#include "hstu_attention_fwd_setting.hpp" +#include "hstu_attention_params.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_block_masking.hpp" +#include "hstu_attention_pipeline_problem.hpp" +#include "hstu_attention_traits.hpp" +#include "hstu_attention_fwd_pipeline.hpp" +#include "hstu_attention_fwd_kernel.hpp" + +template +struct jagged_forward_causal_local_bias_dropout_dispatch +{ + using HstuAttentionShape = typename HstuAttentionFwdShape::Type; + using HstuMask = ck_tile::HstuBlockMasking; + + template + using HstuPipelineProblemTemp = ck_tile::HstuAttentionFwdPipelineProblem< + InOutDataType, + typename HstuAttentionFwdTypeConfig::GemmAccDataType, + typename HstuAttentionFwdTypeConfig::CompDataType, + typename HstuAttentionFwdTypeConfig::BiasDataType, + true, // kIsJagged + kHasBias, + kHasDropout, + HstuMask, + HstuAttentionShape, + HstuTraits>; + + static void Run(HstuAttentionFwdParams& param, hipStream_t stream) + { + constexpr ck_tile::index_t occupancy = -1; + + const bool pad_headdim_qk = !(param.hdim_qk % HstuAttentionShape::kSubQKHeaddim == 0); + const bool pad_headdim_v = !(param.hdim_v % HstuAttentionShape::kN1 == 0); + + // no need to check seqlen_q since it is not used as fastest dim, + // buffer_load_dwordxx/buffer_store_dwordxx can handle oob access + constexpr bool kPadSeqLenQ = false; + + constexpr bool kPadSeqLenK = true; + + BOOL_SWITCH_2(pad_headdim_qk, kPadHeadDimQK, pad_headdim_v, kPadHeadDimV, [&] { + using HstuTraits = ck_tile::HstuAttentionFwdTraits; + + using HstuPipelineProblem = HstuPipelineProblemTemp; + + using HstuEpilogue = ck_tile::Default2DEpilogue::OaccDataType, + typename HstuAttentionFwdTypeConfig::ODataType, + kPadSeqLenQ, + kPadHeadDimV>>; + + using HstuPipeline = ck_tile::HstuAttentionFwdPipelineQRKSVS; + using HstuKernel = ck_tile::HstuAttentionFwdKernel; + + RunWithKernel(param, stream); + }); + }; + + template + static void RunWithKernel(HstuAttentionFwdParams& param, hipStream_t stream) + { + const auto kargs = [&] { + return HstuKernel::MakeKargs(param.q_ptr, + param.k_ptr, + param.v_ptr, + param.bias_ptr, + param.o_ptr, + param.seq_offsets_ptr, + param.hdim_qk, + param.hdim_v, + param.num_head, + param.scale_s, + param.seq_stride_q, + param.seq_stride_k, + param.seq_stride_v, + param.seq_stride_bias, + param.seq_stride_o, + param.nhead_stride_q, + param.nhead_stride_k, + param.nhead_stride_v, + param.nhead_stride_bias, + param.nhead_stride_o, + param.num_targets_ptr, + param.window_size, + param.contextual_seqlen, + param.min_full_attn_seqlen, + param.p_drop, + param.philox_seed, + param.philox_offset); + }(); + + dim3 kGridSize = + HstuKernel::GridSize(param.num_batch, param.num_head, param.seqlen, param.hdim_v); + constexpr dim3 kBlockSize = HstuKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = HstuKernel::kBlockPerCu; + + (void)ck_tile::launch_kernel(ck_tile::stream_config{stream, false}, + ck_tile::make_kernel( + HstuKernel{}, kGridSize, kBlockSize, 0, kargs)); + }; +}; + +template +void run_jagged_forward_causal_local_bias_dropout_dispatch(HstuAttentionFwdParams& param, + hipStream_t stream) +{ + jagged_forward_causal_local_bias_dropout_dispatch::Run(param, stream); +}; diff --git 
a/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_fp16.cpp b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_fp16.cpp new file mode 100644 index 0000000000..c35ddf51b8 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_jagged_forward_fp16.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ +#include +#include + +#include "hstu_attention_bool_switch.hpp" +#include "hstu_attention_hdim_switch.hpp" +#include "hstu_attention_jagged_forward_dispatch.hpp" + +#include "instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp" + +void hstu_attention_jagged_forward_fp16(HstuAttentionFwdParams& param, hipStream_t stream) +{ + const bool has_dropout = (param.p_drop > 0.0f); + const bool has_bias = (param.bias_ptr != nullptr); + const bool use_causal = param.use_causal; + BOOL_SWITCH_3(has_bias, kHasBias, has_dropout, kHasDropout, use_causal, kUseCausal, [&] { + HDIM_SWITCH(param.hdim_qk, param.hdim_v, MaxK, [&] { + if(param.window_size > 0) + { + run_jagged_forward_causal_local_bias_dropout_dispatch(param, stream); + } + else + { + run_jagged_forward_causal_local_bias_dropout_dispatch(param, stream); + }; + }); + }); +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp new file mode 100644 index 0000000000..d76c8857e4 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_params.hpp @@ -0,0 +1,57 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +struct HstuAttentionFwdParams +{ + bool is_jagged; + + ck_tile::index_t num_batch; + ck_tile::index_t seqlen; // batched mode only + const void* seq_offsets_ptr; // jagged mode only + ck_tile::index_t max_seqlen; // jagged mode only + + const void* q_ptr; + const void* k_ptr; + const void* v_ptr; + const void* bias_ptr; + void* o_ptr; + + ck_tile::index_t hdim_qk; + ck_tile::index_t hdim_v; + ck_tile::index_t num_head; + float scale_s; + + ck_tile::index_t seq_stride_q; + ck_tile::index_t seq_stride_k; + ck_tile::index_t seq_stride_v; + ck_tile::index_t seq_stride_bias; + ck_tile::index_t seq_stride_o; + + ck_tile::index_t nhead_stride_q; + ck_tile::index_t nhead_stride_k; + ck_tile::index_t nhead_stride_v; + ck_tile::index_t nhead_stride_bias; + ck_tile::index_t nhead_stride_o; + + // batched mode only parameters + ck_tile::index_t batch_stride_q; + ck_tile::index_t batch_stride_k; + ck_tile::index_t batch_stride_v; + ck_tile::index_t batch_stride_bias; + ck_tile::index_t batch_stride_o; + + const void* num_targets_ptr; + + bool use_causal; + ck_tile::index_t window_size; + ck_tile::index_t contextual_seqlen; + ck_tile::index_t min_full_attn_seqlen; + + float p_drop; + uint64_t philox_seed; + uint64_t philox_offset; +}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp new file mode 100644 index 0000000000..bec43a83a2 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_pipeline_problem.hpp @@ -0,0 +1,62 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
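// Illustrative sketch of filling the stride fields of the HstuAttentionFwdParams
// struct above for a contiguous [seqlen, num_head, hdim] tensor layout (the
// layout is an assumption made only for this sketch; the host example chooses
// the actual strides):
//
//   HstuAttentionFwdParams p{};
//   p.num_head       = num_head;
//   p.hdim_qk        = hdim_qk;
//   p.hdim_v         = hdim_v;
//   p.seq_stride_q   = num_head * hdim_qk; // elements between consecutive tokens
//   p.nhead_stride_q = hdim_qk;            // elements between consecutive heads
//   p.seq_stride_v   = num_head * hdim_v;
//   p.nhead_stride_v = hdim_v;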
+ +#pragma once + +#include + +namespace ck_tile { + +// PipelineProblem encodes information not only from the original user-problem, +// but it also contains other information needed by the pipeline, which includes +// TileShape -- which determines how block-layer calculation is done in tiles and +// how warps are allocated on dimensions +// Traits -- other information required for running the kernel and pipeline + +template +struct HstuAttentionFwdPipelineProblem +{ + using InOutDataType = remove_cvref_t; + using QKVDataType = InOutDataType; + using ODataType = InOutDataType; + using GemmAccDataType = remove_cvref_t; + + // DataType used when siLU calculation + using CompDataType = remove_cvref_t; + using BiasDataType = remove_cvref_t; + + // to be compatible with ck_tile existing policy codes + using QDataType = QKVDataType; + using KDataType = QKVDataType; + using VDataType = QKVDataType; + using SaccDataType = GemmAccDataType; + using OaccDataType = GemmAccDataType; + using PDataType = QKVDataType; + + static constexpr bool kIsJagged = kIsJagged_; + static constexpr bool kHasBias = kHasBias_; + static constexpr bool kHasDropout = kHasDropout_; + + using HstuMask = remove_cvref_t; + + using HstuAttentionTileShape = remove_cvref_t; + + // Keep the name compatible with ck_tile existing policy codes, to be changed + using BlockFmhaShape = HstuAttentionTileShape; + using Traits = remove_cvref_t; + + static constexpr index_t kNumGemm0Warps = AttentionTileShape_::NumGemm0Warps; + static constexpr index_t kNumGemm1Warps = AttentionTileShape_::NumGemm1Warps; + static constexpr index_t kBlockSize = AttentionTileShape_::NumWarps * get_warp_size(); +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_setting.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_setting.hpp deleted file mode 100644 index 3225d05651..0000000000 --- a/example/ck_tile/18_hstu_attention/hstu_attention_setting.hpp +++ /dev/null @@ -1,24 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -// Type configuration -template -struct HSTUAttentionTypeConfig; - -template <> -struct HSTUAttentionTypeConfig -{ - using GemmAccDataType = float; - using SMComputeDataType = float; -}; - -template <> -struct HSTUAttentionTypeConfig -{ - using GemmAccDataType = float; - using SMComputeDataType = float; -}; diff --git a/example/ck_tile/18_hstu_attention/hstu_attention_traits.hpp b/example/ck_tile/18_hstu_attention/hstu_attention_traits.hpp new file mode 100644 index 0000000000..184b5352db --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_attention_traits.hpp @@ -0,0 +1,25 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
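// Worked example for the block-size arithmetic in HstuAttentionFwdPipelineProblem
// above: with Gemm0/Gemm1 warp layouts of <4, 1, 1> (the MaxK = 64/128/256
// settings), NumWarps is 4, so on a wave64 target kBlockSize == 4 * 64 == 256
// threads per workgroup (wave64 is assumed here; get_warp_size() supplies the
// actual value at compile time).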
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct HstuAttentionFwdTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQK = kPadHeadDimQK_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp new file mode 100644 index 0000000000..005b9b29b1 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/hstu_block_masking.hpp @@ -0,0 +1,109 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct HstuBlockMasking +{ + static constexpr bool IsMasking = (kUseCausal || kUseLocal); + + int max_attn_len; + int contextual_seqlen; + int min_full_attn_seqlen; + int max_uih_len; + + CK_TILE_HOST_DEVICE HstuBlockMasking(int max_attn_len_, + int contextual_seqlen_, + int min_full_attn_seqlen_, + int max_uih_len_) + { + max_attn_len = max_attn_len_; + contextual_seqlen = contextual_seqlen_; + min_full_attn_seqlen = min_full_attn_seqlen_; + max_uih_len = max_uih_len_; + }; + + // to get the loop length along X axis, return index:[start, end), end-start=length + // use this if need loop over X axis tile by tile (eg. seqlen_k loop-over) + template + CK_TILE_HOST_DEVICE constexpr auto + GetTileRangeAlongX(index_t i_y, number, number) const + { + if constexpr(!IsMasking) + { + return ck_tile::make_tuple(0, max_uih_len); + } + else + { + if(contextual_seqlen > 0 && (i_y < contextual_seqlen)) + return ck_tile::make_tuple(0, max_uih_len); + + if constexpr(kUseCausal && !kUseLocal) + { + index_t x_end = + min(i_y + YTile, max_uih_len); // for lower-triangular masking, x <= y + + return ck_tile::make_tuple(0, x_end); + } + else if constexpr(!kUseCausal && kUseLocal) + { + if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen) + { + return ck_tile::make_tuple(0, max_uih_len); + } + else + { + index_t x_start = max(0, i_y - max_attn_len); + index_t x_end = i_y + YTile + max_attn_len; + + return ck_tile::make_tuple(x_start - x_start % XTile, x_end); + }; + } + else // kUseCausal && kUseLocal + { + if(min_full_attn_seqlen > 0 && i_y + YTile > max_uih_len - min_full_attn_seqlen) + { + return ck_tile::make_tuple(0, max_uih_len); + } + else + { + index_t x_end = i_y + YTile + max_attn_len; + + return ck_tile::make_tuple(0, x_end); + }; + }; + }; + } + + CK_TILE_HOST_DEVICE constexpr bool IsTokenPairInsideMask(int row, int col) + { + if(row < contextual_seqlen) + return true; + + bool result = false; + if constexpr(kUseLocal) + { + if constexpr(kUseCausal) + result = (row >= col) && (row - col <= max_attn_len); + else + result = std::abs(row - col) <= max_attn_len; + + if(min_full_attn_seqlen > 0) + result = result || (row >= max_uih_len - min_full_attn_seqlen); + } + else + { + if constexpr(kUseCausal) + result = (row >= col); + }; + + return result; + }; +}; + +} // namespace ck_tile diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp new file mode 100644 index 
0000000000..135653e3ee --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..1240e37c28 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..dc61eb8ce7 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..b6fd75d101 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..fecc66f9ef --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..ee6bc64f14 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..58f5c2837b --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..e567c295f4 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..4a9fdb3052 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..6610b3d37d --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..a75ff8e76a --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..a4f6adc2eb --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_instances_ref.hpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_instances_ref.hpp new file mode 100644 index 0000000000..c4a377f9ed --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_instances_ref.hpp @@ -0,0 +1,206 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void 
run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..a222471687 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..f723f0fb38 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..09be671b6c --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..7fc08fc7f0 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..6b2de8cc19 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..fe1e8b3509 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..602de60114 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..fd41687481 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..89160eb715 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..cdef3391a1 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..1f4508cd6b --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..dd0c7978f2 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..48fb53b9e8 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..2bba03d1e9 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..a10e926a29 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..910b50a93f --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..dda7dc07aa --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..ce48b9a401 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..eac83dc8fe --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..6e700fce4b --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..5c0206156c --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..3cf5941e8e --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..f2f1bdc3cd --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..cd9979a4ee --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_instances_ref.hpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_instances_ref.hpp new file mode 100644 index 0000000000..5c350254dc --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_instances_ref.hpp @@ -0,0 +1,206 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + 
true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..64b88f8fe2 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..d399cf6dbb --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
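The batched-forward reference header above (hstu_attention_batched_forward_fp16_instances_ref.hpp) pairs an extern template declaration with each generated instances/*.cpp file, so every (causal, local, bias, dropout, max-K) combination of run_batched_forward_causal_local_bias_dropout_dispatch is compiled once in its own translation unit and only linked against elsewhere. A minimal single-file sketch of that split follows; run_fwd and FwdParams are illustrative stand-ins, not the example's actual symbols, and the real instances launch HIP kernels rather than printing.

// Sketch of the extern-template / explicit-instantiation split used by the
// generated instance files (illustrative names only).
#include <cstdio>

struct FwdParams { int seqlen; };  // stand-in for HstuAttentionFwdParams

// dispatch header: the template is defined once and shared by all instances
template <typename DataType, bool kIsCausal, bool kHasDropout, int kMaxK>
void run_fwd(FwdParams& p)
{
    // the real dispatch would pick a tile-kernel specialisation from the flags
    std::printf("causal=%d dropout=%d maxk=%d seqlen=%d\n",
                int(kIsCausal), int(kHasDropout), kMaxK, p.seqlen);
}

// *_instances_ref.hpp role: the extern declaration suppresses implicit
// instantiation in every translation unit that includes it
extern template void run_fwd<float, true, false, 64>(FwdParams&);

// instances/*.cpp role: each generated file provides exactly one matching
// explicit instantiation definition
template void run_fwd<float, true, false, 64>(FwdParams&);

int main()
{
    FwdParams p{128};
    run_fwd<float, true, false, 64>(p);  // links against the instantiation above
    return 0;
}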
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..91875c4deb --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..3795aa9585 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..653c378852 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..c926f2e9b7 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..859ff63363 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..2d60c18f02 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..bbf5f6817e --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..e3329878ce --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..411afb0ec1 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..49bb0b3ce4 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_batched_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_batched_forward_dispatch.hpp" + +template void run_batched_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..911485ab20 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..c1e80c9666 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..795c8cc4ed --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..7db76cde63 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..0f3ee275fb --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..c875587af3 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..5a8846cc56 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..8e73023e0c --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..71050ae3d7 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..cd165a1861 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..bf40833a7f --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..cef802a33e --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp new file mode 100644 index 0000000000..0bdbdd91fc --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_instances_ref.hpp @@ -0,0 +1,206 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 
256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + true, + true, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..3507c1f29c --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..4f330cd73b --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..22ca2c84ad --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..3b95aeb61b --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..6a06c0e0d3 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..3bbbe402d2 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..dffc329c21 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..8337a1559f --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..3c09e8416f --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..88abedfebf --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..ab45759547 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..2843a7a1c5 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_bf16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::bf16_t, + false, + false, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..e0a29981a8 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..10b05c8a77 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..83d43f43ef --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..459cb224c0 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..5093f4742f --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..60440c4916 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_has_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..9be001a7e0 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..1d55f69805 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..ea8e8895e9 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..0495ee907a --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..88464a5075 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..5be82db604 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_has_causal_has_causal_no_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp new file mode 100644 index 0000000000..c4ae5f851b --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_instances_ref.hpp @@ -0,0 +1,206 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 
256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + true, + true, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); + +extern template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..6d8459a492 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..23de813bbe --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..8aead96641 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..f8ac16a847 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..a36b542f6d --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..59d20c38e7 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_has_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + true, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp new file mode 100644 index 0000000000..54bf3b1674 --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp new file mode 100644 index 0000000000..c78ee824fb --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp new file mode 100644 index 0000000000..195fa6cc0a --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_has_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + true, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp new file mode 100644 index 0000000000..6df4d9713d --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_128.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 128>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp new file mode 100644 index 0000000000..0a7a51034d --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_256.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! 
+ * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 256>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp new file mode 100644 index 0000000000..31fbe9459c --- /dev/null +++ b/example/ck_tile/18_hstu_attention/instances/hstu_attention_jagged_forward_fp16_no_causal_no_causal_no_bias_no_dropout_maxk_64.cpp @@ -0,0 +1,22 @@ + +/* + Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + * + * The file is automatically generated, don't modify! + * See the generator script + * `composable_kernel/example/ck_tile/18_hstu_attention/generate_instances.py` + */ + +#include +#include "hstu_attention_jagged_forward_dispatch.hpp" + +template void run_jagged_forward_causal_local_bias_dropout_dispatch< + ck_tile::fp16_t, + false, + false, + false, + false, + 64>(HstuAttentionFwdParams& param, hipStream_t stream); diff --git a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp index 27cd729a66..1651e546d5 100644 --- a/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp +++ b/example/ck_tile/18_hstu_attention/reference_hstu_attention.hpp @@ -11,16 +11,17 @@ #include #include -#include "bool_switch.hpp" +#include "hstu_attention_bool_switch.hpp" +#include "hstu_block_masking.hpp" namespace ck_tile { // clang-format off // Reference implementation of HSTUAttention problem, which does the following from input tensors: // S[num_batch, num_head, seqlen, seqlen] = Q[num_batch, seqlen, num_head, hdim_qk] @ key^T[num_batch, seqlen, num_head, hdim_v] -// P[num_batch, num_head, seqlen, seqlen] = Masking(SiLu(S[num_batch, num_head, seqlen, seqlen])) +// P[num_batch, num_head, seqlen, seqlen] = SiLU(Masking(S[num_batch, num_head, seqlen, seqlen])) // O[num_batch, num_head, seqlen, hdim_v] = P[num_batch, num_head, seqlen, seqlen] @ value^T[num_batch, num_head, seqlen, hdim_v] -// The process is very similar to the generic attention, the difference is that SiLu is used rather than Softmax, and hstu masking +// The process is very similar to the generic attention, the difference is that SiLU is used rather than Softmax, and hstu masking // is much more complicated than the lower-triangular + disagonal-window based causal mask // clang-format on @@ -32,50 +33,6 @@ template struct reference_hstu_attention { - struct hstu_mask - { - int max_attn_len; - int contextual_seq_len; - int min_full_attn_seq_len; - int max_uih_len; - - hstu_mask(int max_attn_len_, - int contextual_seq_len_, - int min_full_attn_seq_len_, - int max_uih_len_) - { - max_attn_len = max_attn_len_; - contextual_seq_len = contextual_seq_len_; - min_full_attn_seq_len = min_full_attn_seq_len_; - max_uih_len = max_uih_len_; - }; - - bool IsTokenPairInsideMask(int row, int col) - { - if(row < contextual_seq_len) - return true; - - bool result = false; - if constexpr(kUseLocal) - { - if 
constexpr(kUseCausal) - result = (row >= col) && (row - col <= max_attn_len); - else - result = std::abs(row - col) <= max_attn_len; - - if(min_full_attn_seq_len > 0) - result = result || (row >= max_uih_len - min_full_attn_seq_len); - } - else - { - if constexpr(kUseCausal) - result = (row >= col); - }; - - return result; - }; - }; - static void Run(const HostTensor& q_batch_seq_nhead_hdim, const HostTensor& k_batch_seq_nhead_hdim, const HostTensor& v_batch_seq_nhead_hdim, @@ -86,10 +43,10 @@ struct reference_hstu_attention std::vector num_targets, // define masking length at the end of token // sequence to be excluded for attention int max_attn_len, // define the diagonal local window size - int contextual_seq_len, // define masking length at the begin of query token - // sequence to be included for attention - int min_full_attn_seq_len) // define masking length at the end of query token - // sequence which is included for full attention + int contextual_seqlen, // define masking length at the begin of query token + // sequence to be included for attention + int min_full_attn_seqlen) // define masking length at the end of query token + // sequence which is included for full attention { if constexpr(kIsJagged) { @@ -145,13 +102,14 @@ struct reference_hstu_attention int max_uih_len = seqlen; - if(contextual_seq_len > 0) - max_uih_len -= contextual_seq_len - 1; + if(contextual_seqlen > 0) + max_uih_len -= contextual_seqlen - 1; if(has_target) max_uih_len -= num_targets[i_batch]; - hstu_mask mask{max_attn_len, contextual_seq_len, min_full_attn_seq_len, max_uih_len}; + HstuBlockMasking mask{ + max_attn_len, contextual_seqlen, min_full_attn_seqlen, max_uih_len}; // for all rows in the batch for(int sq = 0; sq < max_uih_len; sq++)
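A note on the reference change above, restated outside the diff: the comment now reads P = SiLU(Masking(S)) rather than Masking(SiLu(S)), and the hand-rolled hstu_mask struct is replaced by the shared HstuBlockMasking helper. The following is a minimal standalone sketch that restates the removed IsTokenPairInsideMask logic and the mask-then-SiLU per-row step. It uses plain float buffers and runtime booleans instead of the kIsJagged/kUseCausal/kUseLocal template flags; the names is_inside_mask and reference_row_step are illustrative only, not part of the example's API.

// Standalone sketch (not the kernel or the reference class itself).
#include <cmath>
#include <cstdlib>
#include <vector>

// Same decision logic as the removed hstu_mask::IsTokenPairInsideMask, with the
// compile-time flags turned into runtime booleans for brevity.
bool is_inside_mask(int row, int col, bool use_causal, bool use_local,
                    int max_attn_len, int contextual_seqlen,
                    int min_full_attn_seqlen, int max_uih_len)
{
    if(row < contextual_seqlen)
        return true;

    bool result = false;
    if(use_local)
    {
        if(use_causal)
            result = (row >= col) && (row - col <= max_attn_len);
        else
            result = std::abs(row - col) <= max_attn_len;

        if(min_full_attn_seqlen > 0)
            result = result || (row >= max_uih_len - min_full_attn_seqlen);
    }
    else if(use_causal)
    {
        result = (row >= col);
    }
    return result;
}

// Per-row reference step: mask first, then SiLU, then accumulate into O.
// S is a dense [rows x cols] score tile, V is [cols x hdim_v], O is [rows x hdim_v].
void reference_row_step(const std::vector<float>& S, const std::vector<float>& V,
                        std::vector<float>& O, int rows, int cols, int hdim_v,
                        bool use_causal, bool use_local, int max_attn_len,
                        int contextual_seqlen, int min_full_attn_seqlen, int max_uih_len)
{
    auto silu = [](float x) { return x / (1.0f + std::exp(-x)); };

    for(int r = 0; r < rows; ++r)
        for(int c = 0; c < cols; ++c)
        {
            // P = SiLU(Masking(S)): masked-out token pairs contribute nothing.
            if(!is_inside_mask(r, c, use_causal, use_local, max_attn_len,
                               contextual_seqlen, min_full_attn_seqlen, max_uih_len))
                continue;
            const float p = silu(S[r * cols + c]);
            for(int d = 0; d < hdim_v; ++d)
                O[r * hdim_v + d] += p * V[c * hdim_v + d];
        }
}

int main()
{
    const int rows = 4, cols = 4, hdim_v = 2;
    std::vector<float> S(rows * cols, 1.0f), V(cols * hdim_v, 1.0f), O(rows * hdim_v, 0.0f);
    reference_row_step(S, V, O, rows, cols, hdim_v,
                       /*use_causal=*/true, /*use_local=*/false,
                       /*max_attn_len=*/0, /*contextual_seqlen=*/0,
                       /*min_full_attn_seqlen=*/0, /*max_uih_len=*/rows);
    return 0;
}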
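One note on the bulk of the generated files above: each *_instances_ref.hpp lists every (data type, causal, local, bias, dropout, MaxK) combination as an extern template declaration, while every instances/*.cpp provides exactly one matching explicit instantiation, so the combinations build in parallel and no single translation unit grows large. Below is a reduced sketch of that split, using a placeholder my_dispatch/Params pair rather than the real run_jagged_forward_causal_local_bias_dropout_dispatch signature; everything named here is hypothetical and only illustrates the pattern.

// Dispatch header (analogous to hstu_attention_jagged_forward_dispatch.hpp):
// the template definition that every instance TU includes.
struct Params
{
    int seqlen = 0; // problem description, device pointers, etc.
};

template <typename DataType, bool kCausal, int kMaxK>
void my_dispatch(Params& p)
{
    // The real code would pick a tile configuration here and launch a kernel.
    (void)p;
}

// Instances-reference header (analogous to *_instances_ref.hpp): declarations only,
// so a TU that includes it can call the symbols without instantiating them itself.
extern template void my_dispatch<float, true, 64>(Params&);
extern template void my_dispatch<float, false, 64>(Params&);

// One instance .cpp per combination (analogous to instances/*.cpp): exactly one TU
// owns each definition, keeping per-TU compile time bounded.
template void my_dispatch<float, true, 64>(Params&);
template void my_dispatch<float, false, 64>(Params&);

int main()
{
    Params p;
    my_dispatch<float, true, 64>(p); // resolved against the explicit instantiation
    return 0;
}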