mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
fixing args
This commit is contained in:
@@ -20,22 +20,31 @@
|
||||
#include "fmha_fwd_v3.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
#define INST_UNIFIED_ATTENTION_V3_DISPATCH(kernel_traits) \
|
||||
#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
|
||||
const unified_attention_args& args, const stream_config& config) \
|
||||
std::pair<bool, float> fmha_fwd_v3_kernel_dispatch<kernel_traits>( \
|
||||
const fmha_fwd_v3_args& args, const stream_config& config) \
|
||||
{ \
|
||||
return std::make_pair(true, \
|
||||
unified_attention_kernel_launch<kernel_traits::kernel>(args, config)); \
|
||||
fmha_fwd_v3_kernel_launch<kernel_traits::kernel>(args, config)); \
|
||||
}
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <unified_attention_args::data_type_enum DataType>
|
||||
struct unified_attention_problem_traits;
|
||||
template <fmha_fwd_v3_args::data_type_enum DataType>
|
||||
struct fmha_fwd_v3_problem_traits;
|
||||
|
||||
template <>
|
||||
struct unified_attention_problem_traits<unified_attention_args::data_type_enum::bf16>
|
||||
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::fp16>
|
||||
{
|
||||
using qkvp_dtype = ck_tile::half_t;
|
||||
using acc_dtype = float;
|
||||
using o_dtype = ck_tile::half_t;
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct fmha_fwd_v3_problem_traits<fmha_fwd_v3_args::data_type_enum::bf16>
|
||||
{
|
||||
using qkvp_dtype = ck_tile::bf16_t;
|
||||
using acc_dtype = float;
|
||||
@@ -43,14 +52,13 @@ struct unified_attention_problem_traits<unified_attention_args::data_type_enum::
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
template <unified_attention_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
|
||||
struct unified_attention_kernel_traits
|
||||
template <fmha_fwd_v3_args::data_type_enum DataType, bool IsVariableSeqlen, bool IsMasking>
|
||||
struct fmha_fwd_v3_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_variable_seqlen = IsVariableSeqlen;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
// M0 N0 K0 N1 K1
|
||||
static constexpr bool is_masking = IsMasking
|
||||
// M0 N0 K0 N1 K1
|
||||
using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
|
||||
using fmha_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
using fmha_block_warps = sequence<8, 1, 1>;
|
||||
|
||||
@@ -320,8 +320,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
args.data_type = problem.data_type;
|
||||
args.num_seqs = problem.batch;
|
||||
args.seqlen_q = problem.seqlen_q;
|
||||
args.seqlen_k = problem.seqlen_k;
|
||||
// args.seqlen_q = problem.seqlen_q;
|
||||
// args.seqlen_k = problem.seqlen_k;
|
||||
args.num_head_q = problem.nhead_q;
|
||||
args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv;
|
||||
args.mask_type = 2;
|
||||
@@ -332,7 +332,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
|
||||
// args.query_lens = problem.query_lens
|
||||
// args.kv_lens = problem.kv_lens
|
||||
|
||||
args.num_tokens = problem.batch * problem.seqlen_q;
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.query_stride_0 = problem.hdim * problem.nhead_q;
|
||||
args.query_stride_0 = problem.hdim;
|
||||
@@ -385,7 +385,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
|
||||
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
|
||||
};
|
||||
mask_type
|
||||
calculate_cumulative(eff_query_lens, cu_query_lens);
|
||||
|
||||
ck_tile::DeviceMem seq_lens_buf(kv_lens.size());
|
||||
|
||||
@@ -24,6 +24,7 @@ struct unified_attention_args
|
||||
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
|
||||
// window_size_right == 0).
|
||||
|
||||
index_t num_tokens; // total number of tokens in query
|
||||
index_t num_blks;
|
||||
index_t num_head_q;
|
||||
index_t num_queries_per_kv;
|
||||
|
||||
@@ -58,16 +58,16 @@ struct unified_attention_kernel_traits
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
// BLOCK_Q BLOCK_SIZE HEAD_SIZE N1 K1
|
||||
// BLOCK_Q BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<128, 128, 128>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
using unified_attention_block_warps = sequence<8, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
funified_attention_warp_gemm_shape,
|
||||
funified_attention_block_warps,
|
||||
funified_attention_warp_gemm_shape,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true // IsVLayoutRowMajor
|
||||
>;
|
||||
|
||||
@@ -77,9 +77,9 @@ struct unified_attention_kernel_traits
|
||||
-1 // kBlockPerCu
|
||||
>;
|
||||
|
||||
using funified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
|
||||
using funified_attention_pipeline_problem =
|
||||
using unified_attention_pipeline_problem =
|
||||
UnifiedAttentionPipelineProblem<typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
@@ -89,11 +89,11 @@ struct unified_attention_kernel_traits
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
funified_attention_shape,
|
||||
funified_attention_mask,
|
||||
funified_attention_traits>;
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using funified_attention_pipeline = BlockFunified_attentionFwdV3Pipeline<funified_attention_pipeline_problem>;
|
||||
using unified_attention_pipeline = Blockunified_attentionFwdV3Pipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
@@ -103,7 +103,7 @@ struct unified_attention_kernel_traits
|
||||
true // UseRawStore
|
||||
>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<funified_attention_pipeline, epilogue>;
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
@@ -140,7 +140,10 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const
|
||||
args.num_seqs
|
||||
);
|
||||
|
||||
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, args.total_num_q_blocks);
|
||||
index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs
|
||||
|
||||
|
||||
dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user