composable_kernel/include/ck_tile/host/reference/reference_batched_masking.hpp
Linjun-AMD f5573f56d9 Add attention sink support for FMHA FWD (#3368)
* Revert "Revert "Add attn sink (#2892)" (#3250)"

This reverts commit 5adaa201ed.

* fix conflict

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Add F_sink parameter to FmhaFwdPipeline

* Update tile_fmha_traits.hpp

* Refactor pipeline creation in fmha_fwd.py

Updated the pipeline creation logic to include 'sink' parameter in product combinations and adjusted the FmhaFwdPipeline calls accordingly.

* Update fmha_fwd.py

* Update fmha_fwd.py

* Update example/ck_tile/01_fmha/script/correct_test_fwd_sink.sh

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* update CHANGELOG.md

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>

* Update CHANGELOG with new features and support

* Update fmha_fwd.hpp

* Update CHANGELOG.md

* Update smoke_test_fwd_sink.sh

* Update correct_test_fwd_sink.sh

* Update smoke_test_fwd_sink.sh

---------

Signed-off-by: Linjun-AMD <Jun.Lin@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-15 12:21:59 +08:00


// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>

namespace ck_tile {

// Host reference: overwrite every masked-out element of a [batch, M, N] tensor with -infinity.
// MaskingType must provide IsOutOfSinkBound(m, n), which returns true for positions to mask out.
template <typename CDataType, typename MaskingType>
CK_TILE_HOST void reference_batched_masking(HostTensor<CDataType>& c_b_m_n, const MaskingType& mask)
{
    const int M = c_b_m_n.mDesc.get_lengths()[1];
    const int N = c_b_m_n.mDesc.get_lengths()[2];

    // Applies the mask to a single batch; the mask itself does not depend on the batch index.
    auto f = [&](auto batch) {
        for(int n = 0; n < N; ++n)
        {
            for(int m = 0; m < M; ++m)
            {
                if(mask.IsOutOfSinkBound(m, n))
                    c_b_m_n(batch, m, n) = -ck_tile::numeric<CDataType>::infinity();
            }
        }
    };

    // Distribute the batch dimension across the available hardware threads.
    make_ParallelTensorFunctor(f,
                               c_b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}

} // namespace ck_tile
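
The masking step above can be tried outside of ck_tile with a small standalone sketch. Everything below is illustrative: `ToySlidingWindowSinkMask` and `apply_mask` are hypothetical names, the flat `std::vector` stands in for `HostTensor`, and the sink rule used here (causal sliding window, with the first `sink` keys always kept visible) is an assumed interpretation, not necessarily the semantics of the real ck_tile mask types. Only the `IsOutOfSinkBound(m, n)` method name is taken from the header.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical mask, NOT the ck_tile mask type. Assumed rule:
// a query at row m may see key n if n is not in the future (n <= m) and
// either n is one of the first `sink` keys or n lies inside the last `window` keys.
struct ToySlidingWindowSinkMask
{
    int window; // number of most recent keys visible to each query (including itself)
    int sink;   // number of leading keys that always stay visible

    bool IsOutOfSinkBound(int m, int n) const
    {
        if(n > m)
            return true; // future key: masked by causality
        if(n < sink)
            return false; // sink key: kept visible
        return n <= m - window; // older than the sliding window: masked
    }
};

// Overwrite masked-out scores with -infinity.
// `scores` is a row-major [B][M][N] buffer standing in for HostTensor.
template <typename T, typename Mask>
void apply_mask(std::vector<T>& scores, int B, int M, int N, const Mask& mask)
{
    assert(scores.size() == static_cast<std::size_t>(B) * M * N);
    for(int b = 0; b < B; ++b)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                if(mask.IsOutOfSinkBound(m, n))
                    scores[(static_cast<std::size_t>(b) * M + m) * N + n] =
                        -std::numeric_limits<T>::infinity();
}
```

Writing -infinity rather than zero means that a subsequent softmax assigns exactly zero probability to the masked positions, which is why the reference routine is applied to the score tensor before normalization.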