mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* Re-mapping thread block indices for causal=True kernels * Use more intuitive remap_opt value * Fall back to original remapping if seqlen_q >= 64K * Use GenericAttentionMask to reduce mask computation * Avoid unnecessary boundary check for IsMasking=false case * Fix wrong kernel entry specifier * Add s_nop to prevent delay wave0-3 * Refine scheduling * Remove unnecessary sched_group_barrier() * Move sched_group_barrier() call to scheduler * Replace inline asm s_setprio with intrinsics * Rephrase comments * Expend some o_acc rescaling insts to avoid SIMD idle * Fix block idx special mapping logic * Tune block index mapping for causal=False cases * Tune block index mapping for causal=True cases * Fix wrong vmcnt() * Remove parameter name * Use boolean option to turn on/off causal mask * Update benchmark_fwd_v3.sh option usages * Add option if the compiler supports it
69 lines
1.6 KiB
C++
69 lines
1.6 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include <iostream>
|
|
#include <utility>
|
|
|
|
#include "ck_tile/core/numeric/integer.hpp"
|
|
#include "ck_tile/host/stream_config.hpp"
|
|
|
|
namespace ck_tile {
|
|
|
|
struct fmha_fwd_v3_args
|
|
{
|
|
enum class data_type_enum
|
|
{
|
|
fp16,
|
|
bf16
|
|
};
|
|
|
|
data_type_enum data_type;
|
|
// bool is_varlen;
|
|
|
|
index_t batch;
|
|
index_t seqlen_q;
|
|
index_t seqlen_k;
|
|
index_t nhead_q;
|
|
index_t nhead_kv;
|
|
index_t hdim_qk;
|
|
index_t hdim_v;
|
|
|
|
float softmax_scale;
|
|
|
|
index_t window_size_left;
|
|
index_t window_size_right;
|
|
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
|
|
// window_size_right == 0).
|
|
|
|
const void* q_ptr;
|
|
index_t stride_q;
|
|
index_t nhead_stride_q;
|
|
index_t batch_stride_q;
|
|
|
|
const void* k_ptr;
|
|
index_t stride_k;
|
|
index_t nhead_stride_k;
|
|
index_t batch_stride_k;
|
|
|
|
const void* v_ptr;
|
|
index_t stride_v;
|
|
index_t nhead_stride_v;
|
|
index_t batch_stride_v;
|
|
|
|
void* o_ptr;
|
|
index_t stride_o;
|
|
index_t nhead_stride_o;
|
|
index_t batch_stride_o;
|
|
};
|
|
|
|
std::ostream& operator<<(std::ostream& stream, const fmha_fwd_v3_args::data_type_enum& data_type);
|
|
|
|
// return value:
|
|
// first = whether the kernel was launched (true = launched, false = skipped)
|
|
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
|
std::pair<bool, float> fmha_fwd_v3(const fmha_fwd_v3_args& args, const stream_config& config);
|
|
|
|
} // namespace ck_tile
|