diff --git a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp index b067a9acae..457139fbff 100644 --- a/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp +++ b/example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp @@ -20,22 +20,31 @@ #include "fmha_fwd_v3.hpp" #include "mask.hpp" -#define INST_UNIFIED_ATTENTION_V3_DISPATCH(kernel_traits) \ +#define INST_FMHA_FWD_V3_DISPATCH(kernel_traits) \ template <> \ - std::pair unified_attention_kernel_dispatch( \ - const unified_attention_args& args, const stream_config& config) \ + std::pair fmha_fwd_v3_kernel_dispatch( \ + const fmha_fwd_v3_args& args, const stream_config& config) \ { \ return std::make_pair(true, \ - unified_attention_kernel_launch(args, config)); \ + fmha_fwd_v3_kernel_launch(args, config)); \ } namespace ck_tile { -template -struct unified_attention_problem_traits; +template +struct fmha_fwd_v3_problem_traits; template <> -struct unified_attention_problem_traits +struct fmha_fwd_v3_problem_traits +{ + using qkvp_dtype = ck_tile::half_t; + using acc_dtype = float; + using o_dtype = ck_tile::half_t; + using lse_dtype = float; +}; + +template <> +struct fmha_fwd_v3_problem_traits { using qkvp_dtype = ck_tile::bf16_t; using acc_dtype = float; @@ -43,14 +52,13 @@ struct unified_attention_problem_traits -struct unified_attention_kernel_traits +template +struct fmha_fwd_v3_kernel_traits { static constexpr auto date_type = DataType; static constexpr bool is_variable_seqlen = IsVariableSeqlen; - static constexpr bool is_masking = IsMasking; - - // M0 N0 K0 N1 K1 + static constexpr bool is_masking = IsMasking; + // M0 N0 K0 N1 K1 using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>; using fmha_warp_gemm_shape = sequence<32, 32, 16>; using fmha_block_warps = sequence<8, 1, 1>; diff --git a/example/ck_tile/01_unified_attention/example_unified_attention.cpp b/example/ck_tile/01_unified_attention/example_unified_attention.cpp index 78c711804d..558dba164a 100644 --- 
a/example/ck_tile/01_unified_attention/example_unified_attention.cpp +++ b/example/ck_tile/01_unified_attention/example_unified_attention.cpp @@ -320,8 +320,8 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) args.data_type = problem.data_type; args.num_seqs = problem.batch; - args.seqlen_q = problem.seqlen_q; - args.seqlen_k = problem.seqlen_k; + // args.seqlen_q = problem.seqlen_q; + // args.seqlen_k = problem.seqlen_k; args.num_head_q = problem.nhead_q; args.num_queries_per_kv = problem.nhead_q / problem.nhead_kv; args.mask_type = 2; @@ -332,7 +332,7 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) // args.query_lens = problem.query_lens // args.kv_lens = problem.kv_lens - + args.num_tokens = problem.batch * problem.seqlen_q; args.q_ptr = q_buf.GetDeviceBuffer(); args.query_stride_0 = problem.hdim * problem.nhead_q; args.query_stride_0 = problem.hdim; @@ -385,7 +385,6 @@ bool run_impl(const Problem& problem, const RunConfig& run_config) for(std::size_t i = 0; i < per_batch_vec.size(); ++i) cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i]; }; -mask_type calculate_cumulative(eff_query_lens, cu_query_lens); ck_tile::DeviceMem seq_lens_buf(kv_lens.size()); diff --git a/example/ck_tile/01_unified_attention/unified_attention.hpp b/example/ck_tile/01_unified_attention/unified_attention.hpp index 25f787246b..cf616ef9a5 100644 --- a/example/ck_tile/01_unified_attention/unified_attention.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention.hpp @@ -24,6 +24,7 @@ struct unified_attention_args index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and // window_size_right == 0). 
+ index_t num_tokens; // total number of tokens in query index_t num_blks; index_t num_head_q; index_t num_queries_per_kv; diff --git a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp index 5ac20d354e..4fa0bdab0d 100644 --- a/example/ck_tile/01_unified_attention/unified_attention_impl.hpp +++ b/example/ck_tile/01_unified_attention/unified_attention_impl.hpp @@ -58,16 +58,16 @@ struct unified_attention_kernel_traits static constexpr auto date_type = DataType; static constexpr bool is_masking = IsMasking; - // BLOCK_Q BLOCK_SIZE HEAD_SIZE N1 K1 + // BLOCK_Q BLOCK_SIZE HEAD_SIZE using unified_attention_block_tile = sequence<128, 128, 128>; using unified_attention_warp_gemm_shape = sequence<32, 32, 16>; using unified_attention_block_warps = sequence<8, 1, 1>; using unified_attention_shape = TileUnifiedAttentionShape; @@ -77,9 +77,9 @@ struct unified_attention_kernel_traits -1 // kBlockPerCu >; - using funified_attention_mask = GenericAttentionMask; + using unified_attention_mask = GenericAttentionMask; - using funified_attention_pipeline_problem = + using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem::qkvp_dtype, typename unified_attention_problem_traits::qkvp_dtype, typename unified_attention_problem_traits::qkvp_dtype, @@ -89,11 +89,11 @@ struct unified_attention_kernel_traits typename unified_attention_problem_traits::qkvp_dtype, typename unified_attention_problem_traits::acc_dtype, typename unified_attention_problem_traits::o_dtype, - funified_attention_shape, - funified_attention_mask, - funified_attention_traits>; + unified_attention_shape, + unified_attention_mask, + unified_attention_traits>; - using funified_attention_pipeline = BlockFunified_attentionFwdV3Pipeline; + using unified_attention_pipeline = Blockunified_attentionFwdV3Pipeline; using epilogue = Default2DEpilogue< Default2DEpilogueProblem::acc_dtype, @@ -103,7 +103,7 @@ struct 
unified_attention_kernel_traits true // UseRawStore >>; - using kernel = UnifiedAttentionKernel; + using kernel = UnifiedAttentionKernel; }; template @@ -140,7 +140,10 @@ float unified_attention_kernel_launch(const unified_attention_args& args, const args.num_seqs ); - dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, args.total_num_q_blocks); + index_t total_num_q_blocks = args.num_tokens / Kernel::BLOCK_Q + args.num_seqs; + + + dim3 grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks); constexpr dim3 blocks = Kernel::BlockSize(); constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;