From dc4817921edda44a549197ff3a9dcf5df0636e7b Mon Sep 17 00:00:00 2001 From: Junkai-Wu Date: Thu, 12 Jun 2025 21:10:29 +0800 Subject: [PATCH] v4.0 update. (#2398) * Ex77 fix. --- .../77_blackwell_fmha/77_blackwell_fmha.cu | 91 +++++++++++++++++-- examples/77_blackwell_fmha/CMakeLists.txt | 57 ++++++++---- ..._fmha_fwd_mainloop_tma_warpspecialized.hpp | 20 +++- ...00_fmha_fwd_kernel_tma_warpspecialized.hpp | 1 + 4 files changed, 142 insertions(+), 27 deletions(-) diff --git a/examples/77_blackwell_fmha/77_blackwell_fmha.cu b/examples/77_blackwell_fmha/77_blackwell_fmha.cu index 8fc74c1fa..4f02f9e2d 100644 --- a/examples/77_blackwell_fmha/77_blackwell_fmha.cu +++ b/examples/77_blackwell_fmha/77_blackwell_fmha.cu @@ -116,6 +116,8 @@ struct Options { int h_k = 1; int q = 256; int k = 256; + std::vector varlen_q; + std::vector varlen_k; int d = 128; int warmup_iterations = 1; int iterations = 3; @@ -181,13 +183,76 @@ struct Options { cmd.get_cmd_line_argument("h_k", h_k, -1); if (h_k == -1) h_k = h; + varlen = cmd.check_cmd_line_flag("varlen"); + cmd.get_cmd_line_argument("q", q, -1); cmd.get_cmd_line_argument("k", k, -1); + cmd.get_cmd_line_argument("b", b, -1); + + std::string varlen_q_str; + cmd.get_cmd_line_argument("varlen-q", varlen_q_str); + std::string varlen_k_str; + cmd.get_cmd_line_argument("varlen-k", varlen_k_str); + + if (varlen && ! varlen_q_str.empty()) { + varlen_q.clear(); + while (! varlen_q_str.empty()) { + size_t pos = varlen_q_str.find(':'); + varlen_q.push_back(std::stoi(varlen_q_str.substr(0, pos))); + if (pos == std::string::npos) { + break; + } + varlen_q_str = varlen_q_str.substr(pos + 1); + } + if (b == -1) { + b = static_cast(varlen_q.size()); + } + if (b != static_cast(varlen_q.size())) { + std::cout << "Error: Invalid --varlen-q length\n"; + std::exit(-1); + } + int new_q = 0; + for (auto elem : varlen_q) { + new_q += elem; + } + if (q != -1) { + std::cout << "Error: Can't provide --q and --varlen-q\n"; + std::exit(-1); + } + q = new_q; + } + + if (varlen && ! varlen_k_str.empty()) { + varlen_k.clear(); + while (! varlen_k_str.empty()) { + size_t pos = varlen_k_str.find(':'); + varlen_k.push_back(std::stoi(varlen_k_str.substr(0, pos))); + if (pos == std::string::npos) { + break; + } + varlen_k_str = varlen_k_str.substr(pos + 1); + } + if (b == -1) { + b = static_cast(varlen_k.size()); + } + if (b != static_cast(varlen_k.size())) { + std::cout << " Error: Invalid --varlen-k length\n"; + std::exit(-1); + } + int new_k = 0; + for (auto elem : varlen_k) { + new_k += elem; + } + if (k != -1) { + std::cout << "Error: Can't provide --k and --varlen-k\n"; + std::exit(-1); + } + k = new_k; + } + if (q == -1) q = k; if (k == -1) k = q; if (q == -1 && k == -1) q = k = defaults.q; - - cmd.get_cmd_line_argument("b", b, -1); if (b == -1) b = 16384 / k; if (b == 0) b = 1; @@ -197,7 +262,6 @@ struct Options { verify = cmd.check_cmd_line_flag("verify"); verbose = cmd.check_cmd_line_flag("verbose"); - varlen = cmd.check_cmd_line_flag("varlen"); persistent = cmd.check_cmd_line_flag("persistent"); std::string mask; @@ -240,7 +304,9 @@ struct Options { << " --h_k= Sets the H_K/V extent (for GQA/MQA)\n" << " --q= Sets the Q extent\n" << " --k= Sets the K extent\n" - << " --d= Sets the D extentn" + << " --varlen-q=: Sets the variable Q extent per batch (colon separated)\n" + << " --varlen-k=: Sets the variable K extent per batch (colon separated)\n" + << " --d= Sets the D extent\n" << " --tensor_ring_buffers= Sets the number of tensor ring buffers\n" << " --warmup_iterations= Sets the warmup iterations\n" << " --iterations= Benchmarking iterations\n" @@ -475,7 +541,10 @@ struct FwdRunner { } template - auto initialize_varlen(const ProblemShape& problem_size, const bool kVarlenSame = true) { + auto initialize_varlen( + const Options& options, const ProblemShape& problem_size, + const bool kVarlenSame = true) { + int num_batches = get<3,1>(problem_size); // generate Q as --b times @@ -503,8 +572,12 @@ struct FwdRunner { int max_seqlen_kv = 0; for (int i = 0; i < num_batches; i++) { - int seqlen_q = kVarlenSame ? get<0>(problem_size) : generate_positive_int(dist_q, rng); - int seqlen_kv = kVarlenSame ? get<1>(problem_size) : generate_positive_int(dist_kv, rng); + int seqlen_q = (! options.varlen_q.empty()) ? options.varlen_q.at(i) : + kVarlenSame ? get<0>(problem_size) : + generate_positive_int(dist_q, rng); + int seqlen_kv = (! options.varlen_k.empty()) ? options.varlen_k.at(i) : + kVarlenSame ? get<1>(problem_size) : + generate_positive_int(dist_kv, rng); total_seqlen_q += seqlen_q; total_seqlen_kv += seqlen_kv; @@ -545,7 +618,7 @@ struct FwdRunner { decltype(problem_shape_in) problem_size; if constexpr (kIsVarlen) { - auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in); + auto [problem_shape_init, problem_shape_launch] = initialize_varlen(options, problem_shape_in); problem_shape = problem_shape_launch; problem_size = problem_shape_init; } @@ -588,6 +661,8 @@ struct FwdRunner { buffer.block_V.reset(size(shape_KV), kIsVarlen ? D*SK*H_K : 0); buffer.block_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); buffer.block_LSE.reset(size(shape_LSE)); + buffer.block_ref_O.reset(size(shape_QO), kIsVarlen ? D*SQ*H : 0); + buffer.block_ref_LSE.reset(size(shape_LSE)); initialize_block(buffer.block_Q, seed + 2023, options.init_style_q); initialize_block(buffer.block_K, seed + 2022, options.init_style_k); diff --git a/examples/77_blackwell_fmha/CMakeLists.txt b/examples/77_blackwell_fmha/CMakeLists.txt index 8a30510be..fddff0c72 100644 --- a/examples/77_blackwell_fmha/CMakeLists.txt +++ b/examples/77_blackwell_fmha/CMakeLists.txt @@ -43,6 +43,22 @@ set(TEST_VARLEN --b=1 --h=4 --q=512 --k=512 --d=128 --verify --mask=residual --v set(TEST_HDIM64 --b=2 --h=4 --q=512 --k=512 --d=64 --verify) set(TEST_GQA --b=2 --h=4 --h_k=2 --q=512 --k=512 --d=64 --verify) +set(TEST_VARLEN_00 --verify --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=128 --varlen-k=128) +set(TEST_VARLEN_01 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=128 --varlen-k=128) +set(TEST_VARLEN_02 --verify --mask=causal,residual --d=128 --h=4 --h_k=2 --varlen-q=128 --varlen-k=128) +set(TEST_VARLEN_03 --verify --mask=causal,residual --d=128 --h=8 --h_k=8 --varlen-q=256:256 --varlen-k=512:512) +set(TEST_VARLEN_04 --verify --mask=causal,residual --d=128 --h=8 --h_k=4 --varlen-q=256:256 --varlen-k=512:512) +set(TEST_VARLEN_05 --verify --mask=causal,residual --d=128 --h=8 --h_k=1 --varlen-q=256:256 --varlen-k=512:512) +set(TEST_VARLEN_06 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:768:512:512) +set(TEST_VARLEN_07 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:256:256:256 --varlen-k=256:0:1280:512) +set(TEST_VARLEN_08 --verify --mask=causal,residual --d=128 --h=8 --h_k=2 --varlen-q=256:0:512:256 --varlen-k=256:256:1024:512) +set(TEST_VARLEN_09 --verify --mask=causal,residual --d=64 --h=16 --h_k=16 --varlen-q=100:300 --varlen-k=100:300) +set(TEST_VARLEN_10 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=3:2 --varlen-k=2:5) +set(TEST_VARLEN_11 --verify --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=17:10 --varlen-k=13:10) +set(TEST_VARLEN_12 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=177:845 --varlen-k=257:766) +set(TEST_VARLEN_13 --verify --mask=causal,residual --d=64 --h=4 --h_k=2 --varlen-q=177:366:479 --varlen-k=257:0:766) +set(TEST_VARLEN_14 --verify --mask=causal,residual --d=64 --h=4 --h_k=4 --varlen-q=1 --varlen-k=1) + set(TEST_GEN_BASIC --b=1 --h=4 --k=512 --d=128 --verify) set(TEST_GEN_VARLEN --b=1 --h=4 --k=512 --d=128 --verify --varlen) set(TEST_GEN_HDIM64 --b=2 --h=4 --k=512 --d=64 --verify) @@ -62,10 +78,25 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_fmha.cu TEST_COMMAND_OPTIONS TEST_BASIC - # TEST_CAUSAL - # TEST_VARLEN - # TEST_HDIM64 - # TEST_GQA) + TEST_CAUSAL + TEST_VARLEN + TEST_HDIM64 + TEST_GQA + TEST_VARLEN_00 + TEST_VARLEN_01 + TEST_VARLEN_02 + TEST_VARLEN_03 + TEST_VARLEN_04 + TEST_VARLEN_05 + TEST_VARLEN_06 + TEST_VARLEN_07 + TEST_VARLEN_08 + TEST_VARLEN_09 + TEST_VARLEN_10 + TEST_VARLEN_11 + TEST_VARLEN_12 + TEST_VARLEN_13 + TEST_VARLEN_14 ) target_include_directories(77_blackwell_fmha_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_${PREC} PRIVATE ${PREC_MACRO}) @@ -75,11 +106,11 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_fmha_gen.cu TEST_COMMAND_OPTIONS TEST_GEN_BASIC - # TEST_GEN_VARLEN - # TEST_GEN_HDIM64 - # TEST_GEN_GQA - # TEST_GEN_REMAP - # TEST_GEN_CACHEONLY) + TEST_GEN_VARLEN + TEST_GEN_HDIM64 + TEST_GEN_GQA + TEST_GEN_REMAP + TEST_GEN_CACHEONLY ) target_include_directories(77_blackwell_fmha_gen_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_gen_${PREC} PRIVATE ${PREC_MACRO}) @@ -119,11 +150,7 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_fmha_bwd.cu TEST_COMMAND_OPTIONS TEST_BASIC - # TEST_GEN_VARLEN - # TEST_GEN_HDIM64 - # TEST_GEN_GQA - # TEST_GEN_REMAP - # TEST_GEN_CACHEONLY) + TEST_VARLEN ) target_include_directories(77_blackwell_fmha_bwd_${PREC} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_compile_definitions(77_blackwell_fmha_bwd_${PREC} PRIVATE ${PREC_MACRO}) @@ -160,7 +187,5 @@ if(NOT WIN32 AND (NOT (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) AND (CUTLASS_NVCC 77_blackwell_mla_b2b_2sm_fp16 77_blackwell_fmha_bwd_fp8 77_blackwell_fmha_bwd_fp16 - 77_blackwell_fmha_bwd_sat_fp8 - 77_blackwell_fmha_bwd_sat_fp16 ) endif() diff --git a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp index 58102767f..91f2ee1e8 100644 --- a/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -942,11 +942,15 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { } } - template + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > CUTLASS_DEVICE auto correction( BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, TensorStorageEpi& shared_storage_epi, PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, @@ -1068,10 +1072,15 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { if (epilogue.params.ptr_LSE != nullptr) { int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); if (row_idx < get<0>(problem_shape)) { - gLSE(row_idx, get<2>(blk_coord)) = lse; + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; } } @@ -1101,8 +1110,13 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + if (row_idx < get<0>(problem_shape)) { - gLSE(row_idx, get<2>(blk_coord)) = lse; + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; } } diff --git a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp index e297e7312..ada3ee0b8 100644 --- a/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp +++ b/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -403,6 +403,7 @@ struct Sm100FmhaFwdKernelTmaWarpspecialized { mainloop.correction( blk_coord, params.mainloop, logical_problem_shape, + params.problem_shape, shared_storage.epilogue, pipeline_s0_corr, pipeline_s0_corr_consumer_state, pipeline_s1_corr, pipeline_s1_corr_consumer_state,