mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-01 12:11:19 +00:00
* enable attn sink Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update attn_sink script Signed-off-by: JL-underdog <Jun.Lin@amd.com> * fix some error Signed-off-by: JL-underdog <Jun.Lin@amd.com> * clang-format Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update fmha_bwd mask Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update fmha_bwd_kernel'mask Signed-off-by: JL-underdog <Jun.Lin@amd.com> * update block_fmha_pipeline_qr_ks_vs.hpp Signed-off-by: JL-underdog <Jun.Lin@amd.com> * fix ci error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * fix format error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_bwd_pipeline_default_policy.hpp * Update fmha_fwd_runner.hpp * Update block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * Update fmha_fwd_runner.hpp * update splitkv_pipline Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update splitkv&pagedkv pipeline Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * add sink test Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update attn_sink result log Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update smoke_test_fwd_sink.sh Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update test file Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * update test script Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp * use constexpr kHasSink for sink in fmha pipeline Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * update by pre-commit Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update fmha_fwd.py * Update example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs.hpp Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove causal mask setting logic from mask.hpp Removed the mask setting logic for causal masks. * fix ci error that some usage of lamada not support in c++17 Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update remod.py * add smoke sink test Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update fmha_pagedkv_prefill.py * Update FmhaFwdPipeline parameters in fmha_fwd.py * update block_fmha_pipeline_qr_ks_vs_async_trload.hpp Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * fix c++17 unsupprot error Signed-off-by: LJ-underdog <Jun.Lin@amd.com> * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp * Fix formatting of sink_seq_end assignment * Fix indentation for sink_seq_end assignment * Update block_fmha_fwd_pagedkv_pipeline_qr_ks_vs.hpp --------- Signed-off-by: JL-underdog <Jun.Lin@amd.com> Signed-off-by: LJ-underdog <Jun.Lin@amd.com> Signed-off-by: Linjun-AMD <Jun.Lin@amd.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
33 lines
943 B
C++
33 lines
943 B
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck_tile/core.hpp"
|
|
#include "ck_tile/host/host_tensor.hpp"
|
|
#include <thread>
|
|
|
|
namespace ck_tile {
|
|
|
|
template <typename CDataType, typename MaskingType>
|
|
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
|
|
{
|
|
const int M = c_b_m_n.mDesc.get_lengths()[1];
|
|
const int N = c_b_m_n.mDesc.get_lengths()[2];
|
|
|
|
auto f = [&](auto batch) {
|
|
for(int n = 0; n < N; ++n)
|
|
{
|
|
for(int m = 0; m < M; ++m)
|
|
{
|
|
if(mask.IsOutOfSinkBound(m, n))
|
|
c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
|
|
}
|
|
}
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f,
|
|
c_b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
|
|
}
|
|
} // namespace ck_tile
|