Flash MLA support (#2130)

* initial commit

* initial commit

* fix some error

* update

* bugfix

* bugfix

* change name
This commit is contained in:
myu-guo
2025-02-24 21:31:56 +08:00
committed by GitHub
parent eefa171318
commit 415d587ebf
9 changed files with 1777 additions and 0 deletions

View File

@@ -0,0 +1,475 @@
/***************************************************************************************************
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief A Hopper CUTLASS example for Flash MLA.
*/
#include <cassert>
#include "flash_fwd_mla_kernel.h"
#include "flash_mla.h"
#include <thrust/universal_vector.h>
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include "cutlass/util/command_line.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/tensor_view_io.h"
#include "cutlass/util/distribution.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/util/reference/host/tensor_copy.h"
#include "cutlass/util/reference/host/tensor_compare.h"
#include "cutlass/util/reference/host/tensor_norm.h"
#include "cutlass/util/reference/device/tensor_fill.h"
#include <cuda_runtime.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
#define CUDA_CHECK(status) \
{ \
cudaError_t error = status; \
if (error != cudaSuccess) { \
std::cerr << "CUDA error: " << cudaGetErrorString(error) << " at " << \
__FILE__ << ":" << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
} \
}
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Testbed utility types
/////////////////////////////////////////////////////////////////////////////////////////////////
// Command line options parsing
struct Options {
bool help;
int iterations;
int b, s, h_q, s_q;
int h_kv, d, dv;
float softmax_scale;
bool varlen;
bool causal;
static constexpr int block_size = 64;
Options():
help(false),
b(128), s(4096), h_q(16), s_q(1),
h_kv(1), d(576), dv(512),
varlen(false),
causal(true),
iterations(10)
{ }
// Parses the command line
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
Options defaults;
if (cmd.check_cmd_line_flag("help")) {
help = true;
return;
}
cmd.get_cmd_line_argument("b", b, defaults.b);
cmd.get_cmd_line_argument("s", s, defaults.s);
cmd.get_cmd_line_argument("h_q", h_q, defaults.h_q);
cmd.get_cmd_line_argument("s_q", s_q, defaults.s_q);
cmd.get_cmd_line_argument("h_kv", h_kv, defaults.h_kv);
cmd.get_cmd_line_argument("d", d, defaults.d);
cmd.get_cmd_line_argument("dv", dv, defaults.dv);
if (cmd.check_cmd_line_flag("varlen")) {
varlen = true;
}
cmd.get_cmd_line_argument("iterations", iterations, defaults.iterations);
softmax_scale = 1 / std::sqrt(d);
}
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {
out << "79_hopper_flash_mla\n\n"
<< " Hopper Flash MLA kernel.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement\n\n"
<< " --b=<int> Sets the batch size\n"
<< " --s=<int> Sets the sequence length\n"
<< " --h_q=<int> Sets the number of heads\n"
<< " --s_q=<int> Sets the sequence length of the query\n"
<< " --varlen Sets the varlen as true or false\n"
<< " --iterations=<int> Number of profiling iterations to perform.\n\n";
return out;
}
/// TOOD:Compute performance in GFLOP/s
double gflops(double runtime_s) const
{
// Two flops per multiply-add
// uint64_t flop = uint64_t(2) * m * n * k;
// double gflop = double(flop) / double(1.0e9);
// return gflop / runtime_s;
}
};
/// Helper to initialize a block of device data
template <typename Element>
static void
initialize_values(
thrust::universal_vector<Element>& dst_ptr,
cutlass::Distribution::Kind dist_kind,
uint64_t seed,
Element var = Element(1.f)) {
if (cutlass::Distribution::Uniform == dist_kind) {
int scope = 2;
cutlass::reference::host::BlockFillRandomUniform(
dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0);
}
else if (cutlass::Distribution::AllZeros == dist_kind) {
cutlass::reference::host::BlockFillRandomUniform(
dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0);
}
else if (cutlass::Distribution::AllOnes == dist_kind) {
cutlass::reference::host::BlockFillRandomUniform(
dst_ptr.data().get(), dst_ptr.size(), seed, 1, 1, 0);
}
else if (cutlass::Distribution::Gaussian == dist_kind) {
cutlass::reference::device::BlockFillRandomGaussian(
dst_ptr.data().get(), dst_ptr.size(), seed, (Element) 0, var);
}
else if (cutlass::Distribution::Sequential == dist_kind) {
cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size());
}
else {
std::cerr << "Invalid distribution kind!\n.";
exit(1);
}
}
void initialize_varlen(thrust::universal_vector<int32_t>& block_C, const Options &options) {
block_C.resize(options.b);
std::vector<int32_t> cache_seqlens(options.b, options.s);
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> distribution(options.s, options.s / 2.0f);
for (int i = 0; i < options.b; ++i) {
if (options.varlen) {
float random_length = distribution(gen);
cache_seqlens[i] = std::max(static_cast<int32_t>(random_length), options.s_q);
} else {
cache_seqlens[i] = options.s;
}
}
cutlass::DeviceAllocation<int32_t> d_cache_seqlens(options.b);
CUDA_CHECK(cudaMemcpy(
block_C.data().get(),
cache_seqlens.data(),
options.b * sizeof(int32_t),
cudaMemcpyHostToDevice
));
}
auto initialize_metadata(
thrust::universal_vector<int32_t> &block_C,
thrust::universal_vector<int32_t> &block_MD, thrust::universal_vector<int32_t> &block_S,
int& num_sm_parts,
const Options &options) {
// This should match the logic in the MLA kernel.
static constexpr int block_size_m = 64;
static constexpr int block_size_n = 64;
static constexpr int fixed_overhead_num_blocks = 5;
cudaDeviceProp props;
int current_device;
CUDA_CHECK(cudaGetDevice(&current_device));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
auto batch_size = options.b;
int sm_count = props.multiProcessorCount;
num_sm_parts = sm_count / options.h_kv / cutlass::ceil_div(options.h_kv, block_size_m);
block_MD.resize(num_sm_parts * TileSchedulerMetaDataSize);
block_S.resize(options.b + 1);
Mla_metadata_params params{};
params.seqlens_k_ptr = block_C.data().get();
params.tile_scheduler_metadata_ptr = block_MD.data().get();
params.num_splits_ptr = block_S.data().get();
params.batch_size = batch_size;
params.block_size_n = block_size_n;
params.fixed_overhead_num_blocks = fixed_overhead_num_blocks;
params.num_sm_parts = num_sm_parts;
cudaStream_t stream{nullptr};
get_mla_metadata_func(params, stream);
}
struct TestBed {
using Element = cutlass::bfloat16_t;
using ElementAcc = float;
thrust::universal_vector<Element> block_Q; // query
thrust::universal_vector<Element> block_K; // blocked key
thrust::universal_vector<int32_t> block_T; // block table
thrust::universal_vector<int32_t> block_C; // cache seqlens
// TODO: block_V is not used in the example
// thrust::universal_vector<Element> block_V; // dv
thrust::universal_vector<int32_t> block_MD; // mla metadata
thrust::universal_vector<int32_t> block_S; // num splits
thrust::universal_vector<Element> block_O; // output
thrust::universal_vector<Element> block_LSE; // lse
thrust::universal_vector<ElementAcc> block_O_Accum; // output
thrust::universal_vector<ElementAcc> block_LSE_Accum; // lse
/// Initialize operands to be used in the GEMM and reference GEMM
void initialize(
const Options &options,
int& total_blocks, int& blocks_per_seq, int& num_sm_parts,
uint64_t seed = 2025) {
initialize_varlen(block_C, options);
thrust::device_ptr<int32_t> d_ptr(block_C.data().get());
int64_t total_seqlens = thrust::reduce(d_ptr, d_ptr + options.b);
float sum = static_cast<float>(total_seqlens);
int32_t mean_seqlens = static_cast<int32_t>(sum / options.b);
int32_t max_seqlen = thrust::reduce(d_ptr, d_ptr + options.b,
0,
thrust::maximum<int32_t>());
int max_seqlen_pad = ((max_seqlen + 255) / 256) * 256;
blocks_per_seq = max_seqlen_pad / options.block_size;
total_blocks = options.b * blocks_per_seq;
// Query: [b, s_q, h_q, d]
block_Q.resize(options.b * options.s_q * options.h_q * options.d);
// Block table: [b, max_num_blocks_per_seq]
block_T.resize(total_blocks);
// Key: [b, max_num_blocks_per_seq, block_size, h_kv, d]
block_K.resize(total_blocks * options.block_size * options.h_kv * options.d);
initialize_values(block_Q, cutlass::Distribution::Gaussian, seed + 1);
initialize_values(block_T, cutlass::Distribution::Sequential, seed + 3);
initialize_values(block_K, cutlass::Distribution::Gaussian, seed + 5);
// TODO: Set the exceeding part to NaN
initialize_metadata(block_C, block_MD, block_S, num_sm_parts, options);
int ngroups = options.h_q / options.h_kv;
int num_heads = options.h_kv;
int seqlen_q = options.s_q * ngroups;
// LSE: [batch_size, num_heads, seqlen_q]
block_LSE.resize(options.b * num_heads * seqlen_q);
// Output: [batch_size, seqlen_q, num_heads, head_size_v]
block_O.resize(options.b * seqlen_q * num_heads * options.dv);
auto softmax_lse_size = (options.b + num_sm_parts) * num_heads * seqlen_q;
auto out_accum_size = (options.b + num_sm_parts) * num_heads * seqlen_q * options.dv;
block_LSE_Accum.resize(softmax_lse_size);
block_O_Accum.resize(out_accum_size);
}
/// Execute a given example Flash MLA computation
void run(Options &options)
{
cudaDeviceProp props;
int current_device;
CUDA_CHECK(cudaGetDevice(&current_device));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device));
// TODO: use vcache which is None in the example
auto batch_size = options.b;
auto seqlen_q_ori = options.s_q;
auto num_heads_ori = options.h_q;
auto head_size = options.d;
auto head_size_v = options.dv;
auto num_heads_k = options.h_kv;
auto page_block_size = options.block_size;
int total_blocks, max_num_blocks_per_seq;
int num_sm_parts;
assert(head_size % 8 == 0);
assert(head_size_v % 32 == 0);
initialize(options, total_blocks, max_num_blocks_per_seq, num_sm_parts);
assert(batch_size > 0);
assert(num_heads_ori % num_heads_k == 0);
bool is_causal = seqlen_q_ori == 1 ? false : options.causal;
int ngroups = num_heads_ori / num_heads_k;
int seqlen_q = seqlen_q_ori * ngroups;
int num_heads = num_heads_k;
// TODO: preprocess the query
// q = q.view({batch_size, seqlen_q_ori, num_heads_k, ngroups, head_size}).transpose(2, 3).reshape({batch_size, seqlen_q, num_heads, head_size});
cudaStream_t stream{nullptr};
// set the parameters
Flash_fwd_mla_params kernel_params{};
kernel_params.b = options.b;
kernel_params.seqlen_q = options.s_q;
kernel_params.d = options.d;
kernel_params.d_v = options.dv;
kernel_params.h = options.h_q;
kernel_params.h_h_k_ratio = num_heads_ori / num_heads_k;
kernel_params.ngroups = ngroups;
kernel_params.q_ptr = block_Q.data().get();
kernel_params.k_ptr = block_K.data().get();
// TODO: block_V is not used in the example
kernel_params.v_ptr = block_K.data().get();
kernel_params.o_ptr = block_O.data().get();
kernel_params.softmax_lse_ptr = block_LSE.data().get();
kernel_params.q_batch_stride = seqlen_q * num_heads * options.d;
kernel_params.k_batch_stride = page_block_size * options.h_kv * options.d;
kernel_params.v_batch_stride = page_block_size * options.h_kv * options.dv;
kernel_params.o_batch_stride = options.s_q * options.h_q * options.dv;
kernel_params.q_row_stride = num_heads * options.d;
kernel_params.k_row_stride = options.h_kv * options.d;
kernel_params.v_row_stride = options.h_kv * options.dv;
kernel_params.o_row_stride = options.h_q * options.dv;
kernel_params.q_head_stride = options.d;
kernel_params.k_head_stride = options.d;
kernel_params.v_head_stride = options.dv;
kernel_params.o_head_stride = options.dv;
kernel_params.block_table = block_T.data().get();
kernel_params.block_table_batch_stride = max_num_blocks_per_seq;
kernel_params.page_block_size = page_block_size;
kernel_params.tile_scheduler_metadata_ptr = block_MD.data().get();
kernel_params.num_splits_ptr = block_S.data().get();
kernel_params.softmax_lseaccum_ptr = block_LSE_Accum.data().get();
kernel_params.oaccum_ptr = block_O_Accum.data().get();
kernel_params.is_causal = is_causal;
kernel_params.scale_softmax = options.softmax_scale;
kernel_params.scale_softmax_log2 = std::log2(options.softmax_scale);
kernel_params.cu_seqlens_k = block_C.data().get();
kernel_params.num_sm_parts = num_sm_parts;
assert(head_size == 576);
run_mha_fwd_splitkv_mla<cutlass::bfloat16_t, 576>(kernel_params, stream);
CUDA_CHECK(cudaDeviceSynchronize());
// TODO: postprocess the output
// out = out.view({batch_size, seqlen_q_ori, ngroups, num_heads_k, head_size_v}).transpose(2, 3)
// .reshape({batch_size, seqlen_q_ori, num_heads_ori, head_size_v});
// softmax_lse = softmax_lse.view({batch_size, num_heads_k, seqlen_q_ori, ngroups}).transpose(2, 3)
// .reshape({batch_size, num_heads_ori, seqlen_q_ori});
// TODO: reference check
printf("run done\n");
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////
int main(int argc, char const **args) {
// CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
// and must have compute capability at least 100a.
if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 8)) {
std::cerr << "This example requires CUDA 12.8 or newer." << std::endl;
// Returning zero so this test passes on older Toolkits. Its actions are no-op.
return 0;
}
cudaDeviceProp props;
int current_device_id;
CUDA_CHECK(cudaGetDevice(&current_device_id));
CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
cudaError_t error = cudaGetDeviceProperties(&props, 0);
if (props.major != 9 || props.minor != 0) {
std::cerr << "This example requires a GPU with compute capability 90)." << std::endl;
return 0;
}
//
// Parse options
//
Options options;
options.parse(argc, args);
if (options.help) {
options.print_usage(std::cout) << std::endl;
return 0;
}
//
// Evaluate CUTLASS kernels
//
TestBed testbed{};
testbed.run(options);
return 0;
}
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,36 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# Sparse kernel in this example triggers an ICE in gcc 7.5
if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0))
cutlass_example_add_executable(
68_hopper_flash_mla
68_hopper_flash_mla.cu
)
endif()

View File

@@ -0,0 +1,681 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/flash_fwd_mla_kernel.h
#pragma once
#include <cute/tensor.hpp>
#include <cutlass/cutlass.h>
#include <cutlass/array.h>
#include <cutlass/numeric_types.h>
using namespace cute;
#include "named_barrier.h"
#include "utils.h"
#include "softmax.h"
#include "static_switch.h"
#include "flash_mla.h"
template<typename PrecType, int DIM, int DIM2 = DIM>
constexpr auto getSmemLayoutK() {
constexpr int headSizeBytes = sizeof(PrecType) * DIM;
constexpr int headSizeBytes2 = sizeof(PrecType) * DIM2;
if constexpr (headSizeBytes % 128 == 0 && headSizeBytes2 % 128 == 0) {
return GMMA::Layout_K_SW128_Atom<PrecType>{};
} else if constexpr (headSizeBytes % 64 == 0 && headSizeBytes2 % 64 == 0) {
return GMMA::Layout_K_SW64_Atom<PrecType>{};
} else {
return GMMA::Layout_K_SW32_Atom<PrecType>{};
}
}
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, typename elem_type=cutlass::bfloat16_t, int kHeadDimV_ = 0>
struct Flash_fwd_kernel_traits_mla {
using Element = elem_type;
using ElementAccum = float;
using index_t = int64_t;
static constexpr int kNWarps = kNWarps_;
static constexpr int kNThreads = kNWarps * 32;
static constexpr int kNWarpsS = 4;
static constexpr int kNThreadsS = kNWarpsS * 32;
static constexpr int kBlockM = kBlockM_;
static constexpr int kBlockN = kBlockN_;
static constexpr int kHeadDim = kHeadDim_;
static_assert(kHeadDim % 32 == 0);
static constexpr int kHeadDimV = kHeadDimV_ != 0 ? kHeadDimV_ : kHeadDim;
static_assert(kHeadDimV % 32 == 0);
static_assert(kHeadDimV <= kHeadDim);
static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32;
static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3;
using TiledMma = decltype(make_tiled_mma(
cute::GMMA::ss_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kBlockN>, Int<kHeadDim>>,
GMMA::Major::K, GMMA::Major::K>(),
Layout<Shape<Int<kNWarpsS / 4>, _1, _1>>{}));
static constexpr int AtomLayoutNO = kNThreads / kNThreadsS;
using TiledMmaO = decltype(make_tiled_mma(
cute::GMMA::rs_op_selector<Element, Element, ElementAccum, Shape<Int<kBlockM>, Int<kHeadDimV / AtomLayoutNO>, Int<kBlockN>>,
GMMA::Major::K, GMMA::Major::MN>(),
Layout<Shape<Int<kNWarpsS / 4>, Int<AtomLayoutNO>, _1>>{}));
using SmemLayoutQ = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim>(),
Shape<Int<kBlockM>, Int<kHeadDim>>{}));
using SmemLayoutK = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
Shape<Int<kBlockN>, Int<kHeadDim>>{}));
using SmemLayoutV = decltype(tile_to_shape(
getSmemLayoutK<Element, kHeadDim, kHeadDimV>(),
Shape<Int<kBlockN>, Int<kHeadDimV>>{}));
using SmemLayoutVtransposed = decltype(composition(SmemLayoutV{}, make_layout(Shape<Int<kHeadDimV>, Int<kBlockN>>{}, GenRowMajor{})));
using SmemLayoutP = Layout<Shape<Shape<_2, _2>, Int<kNThreadsS>, _1, Int<kBlockN / 8>>>;
using SmemLayoutRow = Layout<Shape<_2, Int<kNThreadsS>>, Stride<_1, _2>>;
using SmemLayoutAtomO = decltype(composition(
Swizzle<kSwizzle, 3, 3>{},
Layout<Shape<Int<8>, Int<kBlockKSmem>>, Stride<Int<kBlockKSmem>, _1>>{}));
using SmemLayoutO = decltype(tile_to_shape(
SmemLayoutAtomO{},
Shape<Int<kBlockM>, Int<kHeadDimV>>{}));
using SmemCopyAtomO = Copy_Atom<SM90_U32x4_STSM_N, Element>;
using SmemCopyAtomOaccum = Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>;
static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element);
static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad");
static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad;
using Gmem_copy_struct = SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>;
static constexpr int kNThreadsLoad = kNThreads - kNThreadsS;
static_assert(kNThreadsLoad % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow");
using GmemLayoutAtom = Layout<
Shape<Int<kNThreadsLoad / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopy = decltype(make_tiled_copy(
Copy_Atom<Gmem_copy_struct, Element>{},
GmemLayoutAtom{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per read
using GmemLayoutAtomO = Layout<
Shape<Int<kNThreadsS / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
Stride<Int<kGmemThreadsPerRow>, _1>>;
using GmemTiledCopyO = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, Element>{},
GmemLayoutAtomO{},
Layout<Shape<_1, _8>>{})); // Val layout, 8 vals per store
static constexpr int kGmemElemsPerLoadAccum = sizeof(cute::uint128_t) / sizeof(ElementAccum);
static constexpr int kGmemThreadsPerRowAccum = kBlockKSmem / kGmemElemsPerLoadAccum;
using GmemLayoutAtomOaccum = Layout<
Shape<Int<kNThreadsS / kGmemThreadsPerRowAccum>, Int<kGmemThreadsPerRowAccum>>,
Stride<Int<kGmemThreadsPerRowAccum>, _1>>;
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
GmemLayoutAtomOaccum{},
Layout<Shape<_1, _4>>{})); // Val layout, 4 vals per store
};
namespace flash {
using namespace cute;
template<typename Kernel_traits>
struct SharedStorageMLA {
union {
struct {
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutQ>> smem_q;
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutK> * 2> smem_k; // Double buffer
cute::array_aligned<typename Kernel_traits::Element, cute::cosize_v<typename Kernel_traits::SmemLayoutP>> smem_p;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_scale;
};
struct {
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_max;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutRow>> smem_sum;
cute::array_aligned<typename Kernel_traits::ElementAccum, cute::cosize_v<typename Kernel_traits::SmemLayoutO>> smem_o;
};
};
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, bool Split, typename SharedStorage, typename AccO, typename Softmax>
__forceinline__ __device__ void store(const Flash_fwd_mla_params &params, const int bidb, const int bidh, const int m_block, const int n_split_idx,
SharedStorage &shared_storage, AccO tOrO, Softmax softmax) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
using Element = typename Kernel_traits::Element;
using ElementAccum = typename Kernel_traits::ElementAccum;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
typename Kernel_traits::TiledMmaO tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
// Epilogue
const int split_offset = __ldg(params.num_splits_ptr + bidb);
Tensor lse = softmax.template normalize_softmax_lse</*Is_dropout=*/false, Split>(tOrO, params.scale_softmax);
using ElementO = std::conditional_t<!Split, Element, ElementAccum>;
Tensor sOaccum = make_tensor(make_smem_ptr(reinterpret_cast<ElementO *>(shared_storage.smem_o.data())), typename Kernel_traits::SmemLayoutO{}); // (SMEM_M,SMEM_N)
// Partition sO to match the accumulator partitioning
using SmemTiledCopyO = std::conditional_t<
!Split,
typename Kernel_traits::SmemCopyAtomO,
typename Kernel_traits::SmemCopyAtomOaccum
>;
auto smem_tiled_copy_Oaccum = make_tiled_copy_C(SmemTiledCopyO{}, tiled_mma_o);
auto smem_thr_copy_Oaccum = smem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor rO = flash::convert_type<ElementO>(tOrO);
Tensor taccOrOaccum = smem_thr_copy_Oaccum.retile_S(rO); // ((Atom,AtomNum), MMA_M, MMA_N)
Tensor taccOsOaccum = smem_thr_copy_Oaccum.partition_D(sOaccum); // ((Atom,AtomNum),PIPE_M,PIPE_N)
__syncthreads();
cute::copy(smem_tiled_copy_Oaccum, taccOrOaccum, taccOsOaccum);
const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
const index_t row_offset_oaccum = (((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_v;
const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
const index_t row_offset_lseaccum = ((split_offset + n_split_idx) * params.h + bidh) * params.seqlen_q + m_block * kBlockM;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)),
Shape<Int<kBlockM>, Int<kHeadDimV>>{},
make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (Split ? row_offset_lseaccum : row_offset_lse)),
Shape<Int<kBlockM>>{}, Stride<_1>{});
using GmemTiledCopyO = std::conditional_t<!Split, typename Kernel_traits::GmemTiledCopyO, typename Kernel_traits::GmemTiledCopyOaccum>;
GmemTiledCopyO gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOsOaccum = gmem_thr_copy_Oaccum.partition_S(sOaccum); // ((Atom,AtomNum),ATOM_M,ATOM_N)
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_D(gOaccum);
__syncthreads();
if (tidx >= kNThreadsS) { return; }
Tensor tOrOaccum = make_tensor<ElementO>(shape(tOgOaccum));
cute::copy(gmem_tiled_copy_Oaccum, tOsOaccum, tOrOaccum);
Tensor caccO = make_identity_tensor(Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor taccOcO = thr_mma_o.partition_C(caccO); // ((MMA=4, X), MMA_M, MMA_K=1)
Tensor taccOcO_row = taccOcO(make_coord(0, _, 0), _, 0);
CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M
if (get<1>(taccOcO_row(0)) == 0) {
#pragma unroll
for (int mi = 0; mi < size(lse); ++mi) {
const int row = get<0>(taccOcO_row(mi));
if (row < params.seqlen_q - m_block * kBlockM) { gLSEaccum(row) = lse(mi); }
}
}
// Construct identity layout for sO
Tensor cO = make_identity_tensor(make_shape(size<0>(sOaccum), size<1>(sOaccum))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
// Repeat the partitioning with identity layouts
Tensor tOcO = gmem_thr_copy_Oaccum.partition_D(cO); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tOpO = make_tensor<bool>(make_shape(size<2>(tOgOaccum)));
// Clear_OOB_K must be false since we don't want to write zeros to gmem
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_tiled_copy_Oaccum, tOrOaccum, tOgOaccum, tOcO, tOpO, params.seqlen_q - m_block * kBlockM
);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla(const Flash_fwd_mla_params &params,
const int bidb, const int bidh, const int m_block,
const int n_split_idx, const int seqlen_k,
const int n_block_min, const int n_block_max, const bool NoSplit,
SharedStorage &shared_storage) {
constexpr int kBlockM = Kernel_traits::kBlockM;
constexpr int kBlockN = Kernel_traits::kBlockN;
constexpr int kHeadDim = Kernel_traits::kHeadDim;
constexpr int kHeadDimV = Kernel_traits::kHeadDimV;
constexpr int kNThreads = Kernel_traits::kNThreads;
constexpr int kNThreadsS = Kernel_traits::kNThreadsS;
static_assert(kNThreads == 256 and kNThreadsS == 128);
using Element = typename Kernel_traits::Element;
using index_t = typename Kernel_traits::index_t;
const int tidx = threadIdx.x;
int n_block = n_block_max - 1;
Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Kernel_traits::SmemLayoutQ{});
Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutK{});
Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutV{});
Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Kernel_traits::SmemLayoutVtransposed{});
Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Kernel_traits::SmemLayoutP{});
Tensor tPsP = sP(_, tidx % kNThreadsS, _, _);
Tensor sScale_o = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tScale_osScale_o = sScale_o(_, tidx % kNThreadsS);
Tensor sRow_max = make_tensor(make_smem_ptr(shared_storage.smem_max.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tRow_maxsRow_max = sRow_max(_, tidx % kNThreadsS);
Tensor sRow_sum = make_tensor(make_smem_ptr(shared_storage.smem_sum.data()), typename Kernel_traits::SmemLayoutRow{});
Tensor tRow_sumsRow_sum = sRow_sum(_, tidx % kNThreadsS);
typename Kernel_traits::TiledMmaO tiled_mma_o;
auto thr_mma_o = tiled_mma_o.get_thread_slice(tidx);
Tensor tOrVt = thr_mma_o.partition_fragment_B(sVt); // (MMA, MMA_K,MMA_N)
Tensor tOrO = partition_fragment_C(tiled_mma_o, Shape<Int<kBlockM>, Int<kHeadDimV>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
clear(tOrO);
flash::Softmax<2 * size<1>(tOrO)> softmax;
int warp_group_idx = cutlass::canonical_warp_group_idx();
if (warp_group_idx == 0) {
typename Kernel_traits::TiledMma tiled_mma;
auto thr_mma = tiled_mma.get_thread_slice(tidx);
Tensor tSrQ = thr_mma.partition_fragment_A(sQ); // (MMA,MMA_M,MMA_K)
Tensor tSrK = thr_mma.partition_fragment_B(sK); // (MMA,MMA_N,MMA_K)
if (n_block % 2 == 1) {
// Double buffer for sK
constexpr int sK_offset = size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
// We need masking on S for the very last block when K and V has length not multiple of kBlockN.
// We also need masking on S if it's causal, for the last ceil_div(kBlockM, kBlockN) blocks.
// We will have at least 1 "masking" iteration.
// If not even_N, then seqlen_k might end in the middle of a block. In that case we need to
// mask 2 blocks (e.g. when kBlockM == kBlockN), not just 1.
constexpr int n_masking_steps = !Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
#pragma unroll 1
for (int masking_step = n_masking_steps; n_block >= n_block_min; --masking_step, --n_block) {
__syncthreads();
Tensor tSrS = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}); // ((MMA=4, X), MMA_M, MMA_N=1)
flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma, tSrQ, tSrK, tSrS);
const bool is_masking_step = masking_step > 0;
const bool is_first_masking_step = masking_step == n_masking_steps;
if (is_masking_step) {
Tensor cS = make_identity_tensor(Shape<Int<kBlockM>, Int<kBlockN>>{});
Tensor tScS = thr_mma.partition_C(cS);
#pragma unroll
for (int i = 0; i < size(tSrS); ++i) {
if constexpr (!Is_causal) { // Just masking based on col
if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) tSrS(i) = -INFINITY;
} else {
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int row = int(get<0>(tScS(i)));
int col_limit_right = seqlen_k - 1 - n_block * kBlockN - (params.seqlen_q - 1 - (m_block * kBlockM + row)) / params.ngroups;
if (int(get<1>(tScS(i))) > col_limit_right) tSrS(i) = -INFINITY;
}
}
}
// We have key_padding_mask so we'll need to Check_inf
Tensor scale_o = is_first_masking_step
? softmax.template softmax</*Is_first=*/true, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
: is_masking_step ?
softmax.template softmax</*Is_first=*/false, /*Check_inf=*/Is_causal>(tSrS, params.scale_softmax_log2)
: softmax.template softmax</*Is_first=*/false, /*Check_inf=*//*Is_local=*/false>(tSrS, params.scale_softmax_log2);
Tensor rP = flash::convert_type<Element>(tSrS);
cute::copy(rP, tPsP);
cute::copy(scale_o, tScale_osScale_o);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SReady));
flash::rescale_o(tOrO, scale_o);
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tSrK.data() = tSrK.data() + sK_offset / 8;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cute::copy(softmax.row_max, tRow_maxsRow_max);
cute::copy(softmax.row_sum, tRow_sumsRow_sum);
cutlass::arch::NamedBarrier::arrive(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
} else {
const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
int cur_block_table = __ldg(&block_table[n_block]);
const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
Shape<Int<kBlockM>, Int<kHeadDim>>{},
make_stride(params.q_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_Q;
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx - kNThreadsS);
Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ))); // (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
params.seqlen_q - m_block * kBlockM);
const index_t row_offset_k = (bidh / params.h_h_k_ratio) * params.k_head_stride;
Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
Shape<Int<kBlockN>, Int<kHeadDim>>{},
make_stride(params.k_row_stride, _1{}));
typename Kernel_traits::GmemTiledCopy gmem_tiled_copy_K;
auto gmem_thr_copy_K = gmem_tiled_copy_K.get_thread_slice(tidx - kNThreadsS);
Tensor tKgK = gmem_thr_copy_K.partition_S(gK);
Tensor tKsK = gmem_thr_copy_K.partition_D(sK);
Tensor cK = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK))); // (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor tKcK = gmem_thr_copy_K.partition_S(cK); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor tKpK = make_tensor<bool>(make_shape(size<2>(tKsK)));
if (n_block % 2 == 1) {
// Double buffer for sK
constexpr int sK_offset = size(sK);
tKsK.data() = tKsK.data() + sK_offset;
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
// We need to clear the sK smem tiles because K is V.
const index_t offset_k = cur_block_table * params.k_batch_stride;
tKgK.data() = tKgK.data() + offset_k;
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true, /*Clear_OOB_MN=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK,
seqlen_k - n_block * kBlockN);
tKgK.data() = tKgK.data() + -offset_k;
cute::cp_async_fence();
if (n_block - 1 >= n_block_min) {
cur_block_table = __ldg(&block_table[n_block - 1]);
}
#pragma unroll 1
for (; n_block >= n_block_min; --n_block) {
flash::cp_async_wait<0>();
__syncthreads();
if (n_block - 1 >= n_block_min) {
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tKsK.data() = tKsK.data() + sK_offset;
const index_t offset_k = cur_block_table * params.k_batch_stride;
tKgK.data() = tKgK.data() + offset_k;
flash::copy</*Is_even_MN=*/true, /*Is_even_K=*/true>(gmem_tiled_copy_K, tKgK, tKsK, tKcK, tKpK);
tKgK.data() = tKgK.data() + -offset_k;
cute::cp_async_fence();
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SReady));
if (n_block - 2 >= n_block_min) {
cur_block_table = __ldg(&block_table[n_block - 2]);
}
typename Kernel_traits::TiledMma tiled_mma;
auto tSrS_layout = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{}).layout();
Tensor rP = make_tensor<Element>(tSrS_layout);
Tensor scale_o = make_tensor<float>(Shape<_2>{});
cute::copy(tScale_osScale_o, scale_o);
cute::copy(tPsP, rP);
flash::rescale_o(tOrO, scale_o);
Tensor tOrP = make_tensor(rP.data(), flash::convert_layout_acc_Aregs<Kernel_traits::TiledMma>(rP.layout()));
flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma_o, tOrP, tOrVt, tOrO);
// Double buffer for sK
const int sK_offset = n_block % 2 == 0 ? size(sK) : -size(sK);
tOrVt.data() = tOrVt.data() + sK_offset / 8;
}
cutlass::arch::NamedBarrier::sync(kNThreads, static_cast<int>(NamedBarriers::SoftmaxReady));
cute::copy(tRow_maxsRow_max, softmax.row_max);
cute::copy(tRow_sumsRow_sum, softmax.row_sum);
}
if (NoSplit)
store<Kernel_traits, false>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
else
store<Kernel_traits, true>(params, bidb, bidh, m_block, n_split_idx, shared_storage, tOrO, softmax);
}
template<typename Kernel_traits, bool Is_causal, typename SharedStorage>
__global__ void __launch_bounds__(Kernel_traits::kNThreads, 1, 1)
flash_fwd_splitkv_mla_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
constexpr int kBlockN = Kernel_traits::kBlockN;
const int m_block = blockIdx.x;
const int bidh = blockIdx.y;
const int partition_idx = blockIdx.z;
extern __shared__ char shared_memory[];
auto &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr + partition_idx * TileSchedulerMetaDataSize;
int4 tile_scheduler_metadata = __ldg(reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr));
int begin_idx = tile_scheduler_metadata.x;
int begin_seqlen = tile_scheduler_metadata.y;
int end_idx = tile_scheduler_metadata.z;
int end_seqlen = tile_scheduler_metadata.w;
if (begin_idx >= params.b) return;
int begin_n_split_idx = __ldg(tile_scheduler_metadata_ptr + 4);
#pragma unroll 1
for (int batch_id = begin_idx; batch_id <= end_idx; ++batch_id) {
const int n_split_idx = batch_id == begin_idx ? begin_n_split_idx : 0;
const int seqlen_k = __ldg(params.cu_seqlens_k + batch_id);
const int n_block_min = batch_id == begin_idx ? begin_seqlen / kBlockN : 0;
const int n_block_max = batch_id == end_idx ? cute::ceil_div(end_seqlen, kBlockN) : cute::ceil_div(seqlen_k, kBlockN);
const bool NoSplit = n_block_min == 0 && n_block_max == cute::ceil_div(seqlen_k, kBlockN);
if (batch_id > begin_idx) {
__syncthreads(); // Barrier between two tiles.
}
flash::compute_attn_1rowblock_splitkv_mla<Kernel_traits, Is_causal>(params, batch_id, bidh, m_block, n_split_idx, seqlen_k, n_block_min, n_block_max, NoSplit, shared_storage);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Element, typename ElementAccum, typename index_t, int kHeadDimV, int kMaxSplits>
__global__ void __launch_bounds__(256, 1, 1)
flash_fwd_splitkv_mla_combine_kernel(__grid_constant__ const Flash_fwd_mla_params params) {
constexpr int kNThreads = 128;
const int tidx = threadIdx.x;
const int bidx = blockIdx.x;
const int hs = params.h * params.seqlen_q;
const int batch_idx = bidx / hs;
const int hs_idx = bidx % hs;
const int split_offset = __ldg(params.num_splits_ptr + batch_idx);
const int actual_num_splits = __ldg(params.num_splits_ptr + batch_idx + 1) - split_offset;
FLASH_DEVICE_ASSERT(actual_num_splits <= kMaxSplits);
if (actual_num_splits == 1) return;
__shared__ ElementAccum sLseScale[kMaxSplits];
const index_t row_offset_lseaccum = split_offset * hs + hs_idx;
const index_t row_offset_lse = bidx;
Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lseaccum_ptr) + row_offset_lseaccum),
Shape<Int<kMaxSplits>>{}, make_stride(hs));
Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.softmax_lse_ptr) + row_offset_lse),
Shape<_1>{}, Stride<_1>{});
int warp_idx = cutlass::canonical_warp_idx_sync();
if (warp_idx == 0) {
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
float local_lse[kNLsePerThread];
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + tidx;
local_lse[i] = split < actual_num_splits ? gLSEaccum(split) : -INFINITY;
}
float max_lse = -INFINITY;
for (int i = 0; i < kNLsePerThread; ++i) max_lse = max(max_lse, local_lse[i]);
for (int offset = 16; offset >= 1; offset /= 2) max_lse = max(max_lse, __shfl_xor_sync(uint32_t(-1), max_lse, offset));
max_lse = max_lse == -INFINITY ? 0.0f : max_lse; // In case all local LSEs are -inf
float sum_lse = 0;
for (int i = 0; i < kNLsePerThread; ++i) sum_lse = sum_lse + expf(local_lse[i] - max_lse);
for (int offset = 16; offset >= 1; offset /= 2) sum_lse = sum_lse + __shfl_xor_sync(uint32_t(-1), sum_lse, offset);
float global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? INFINITY : logf(sum_lse) + max_lse;
if (tidx == 0) gLSE(0) = global_lse;
for (int i = 0; i < kNLsePerThread; ++i) {
const int split = i * 32 + tidx;
if (split < actual_num_splits) sLseScale[split] = expf(local_lse[i] - global_lse);
}
}
__syncthreads();
static_assert(kHeadDimV % kNThreads == 0);
constexpr int Elements = kHeadDimV / kNThreads;
const index_t row_offset_oaccum = (split_offset * hs + hs_idx) * kHeadDimV;
Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(params.oaccum_ptr) + row_offset_oaccum),
Shape<Int<kHeadDimV>>{}, Stride<_1>{});
using GmemTiledCopyOaccum = decltype(make_tiled_copy(
Copy_Atom<AutoVectorizingCopyWithAssumedAlignment<128>, ElementAccum>{},
Layout<Shape<Int<kNThreads>>>{},
Layout<Shape<Int<Elements>>>{}));
GmemTiledCopyOaccum gmem_tiled_copy_Oaccum;
auto gmem_thr_copy_Oaccum = gmem_tiled_copy_Oaccum.get_thread_slice(tidx);
Tensor tOgOaccum = gmem_thr_copy_Oaccum.partition_S(gOaccum);
Tensor tOrOaccum = make_tensor<ElementAccum>(shape(tOgOaccum));
Tensor tOrO = make_tensor<ElementAccum>(shape(tOgOaccum));
clear(tOrO);
for (int split = 0; split < actual_num_splits; ++split) {
cute::copy(tOgOaccum, tOrOaccum);
ElementAccum lse_scale = sLseScale[split];
for (int i = 0; i < size(tOrO); ++i) {
tOrO(i) += lse_scale * tOrOaccum(i);
}
tOgOaccum.data() = tOgOaccum.data() + hs * kHeadDimV;
}
Tensor rO = flash::convert_type<Element>(tOrO);
const int head_idx = (bidx - batch_idx * hs) / params.seqlen_q;
const int row = bidx - batch_idx * hs - head_idx * params.seqlen_q;
auto o_ptr = reinterpret_cast<Element *>(params.o_ptr) + batch_idx * params.o_batch_stride + head_idx * params.o_head_stride + row * params.o_row_stride;
Tensor gO = make_tensor(make_gmem_ptr(o_ptr + tidx * Elements), Shape<Int<decltype(size<0>(rO))::value>>{}, Stride<_1>{});
cute::copy(rO, gO);
}
} // namespace flash
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename Kernel_traits, typename SharedStorage>
void run_flash_splitkv_fwd_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.page_block_size == Kernel_traits::kBlockN);
const int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
BOOL_SWITCH(params.is_causal, Is_causal, [&] {
auto kernel = &flash::flash_fwd_splitkv_mla_kernel<Kernel_traits, Is_causal, SharedStorage>;
constexpr size_t smem_size = sizeof(SharedStorage);
CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
kernel<<<dim3(num_m_block, params.h, params.num_sm_parts), Kernel_traits::kNThreads, smem_size, stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();
dim3 grid_combine(params.b * params.h * params.seqlen_q);
MLA_NUM_SPLITS_SWITCH(params.num_sm_parts, kMaxSplits, [&] {
auto combine_kernel = &flash::flash_fwd_splitkv_mla_combine_kernel<
typename Kernel_traits::Element, typename Kernel_traits::ElementAccum, typename Kernel_traits::index_t, Kernel_traits::kHeadDimV, kMaxSplits>;
combine_kernel<<<grid_combine, 128, 0, stream>>>(params);
});
CHECK_CUDA_KERNEL_LAUNCH();
}
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream) {
static_assert(Headdim == 576);
FLASH_ASSERT(params.d_v == 512);
FLASH_ASSERT(params.k_ptr == params.v_ptr); // Shared_KV
using Kernel_traits = Flash_fwd_kernel_traits_mla<576, 64, 64, 8, T, 512>;
run_flash_splitkv_fwd_mla<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>>(params, stream);
}
static constexpr int MaxBatchSize = 4096;
__global__ void __launch_bounds__(256, 1, 1)
get_mla_metadata_kernel(__grid_constant__ const Mla_metadata_params params) {
int *seqlens_k_ptr = params.seqlens_k_ptr;
int *tile_scheduler_metadata_ptr = params.tile_scheduler_metadata_ptr;
int *num_splits_ptr = params.num_splits_ptr;
int batch_size = params.batch_size;
int block_size_n = params.block_size_n;
int fixed_overhead_num_blocks = params.fixed_overhead_num_blocks;
int num_sm_parts = params.num_sm_parts;
__shared__ int num_blocks_shared[MaxBatchSize];
__shared__ int num_splits_shared[MaxBatchSize];
int total_num_blocks = 0;
for (int i = threadIdx.x; i < batch_size; i += 32) {
int num_blocks = cutlass::ceil_div(seqlens_k_ptr[i], block_size_n);
total_num_blocks += num_blocks + fixed_overhead_num_blocks;
num_blocks_shared[i] = num_blocks;
}
for (int offset = 16; offset >= 1; offset /= 2) {
total_num_blocks += __shfl_xor_sync(uint32_t(-1), total_num_blocks, offset);
}
__syncwarp();
if (threadIdx.x == 0) {
int payload = cutlass::ceil_div(total_num_blocks, num_sm_parts) + fixed_overhead_num_blocks;
int now_idx = 0, now_block = 0, now_n_split_idx = 0, cum_num_splits = 0;
num_splits_shared[0] = 0;
for (int i = 0; i < num_sm_parts; ++i) {
int tile_scheduler_metadata0[4], tile_scheduler_metadata1;
tile_scheduler_metadata0[0] = now_idx;
tile_scheduler_metadata0[1] = now_block * block_size_n;
tile_scheduler_metadata1 = now_n_split_idx;
int remain_payload = payload;
while (now_idx < batch_size) {
int num_blocks = num_blocks_shared[now_idx];
int now_remain_blocks = num_blocks - now_block;
if (remain_payload >= now_remain_blocks + fixed_overhead_num_blocks) {
cum_num_splits += now_n_split_idx + 1;
num_splits_shared[now_idx + 1] = cum_num_splits;
remain_payload -= now_remain_blocks + fixed_overhead_num_blocks;
++now_idx;
now_block = 0;
now_n_split_idx = 0;
} else {
if (remain_payload - fixed_overhead_num_blocks > 0) {
now_block += remain_payload - fixed_overhead_num_blocks;
++now_n_split_idx;
remain_payload = 0;
}
break;
}
}
tile_scheduler_metadata0[2] = now_block > 0 ? now_idx : now_idx - 1;
tile_scheduler_metadata0[3] = now_block > 0 ? now_block * block_size_n : seqlens_k_ptr[now_idx - 1];
*reinterpret_cast<int4 *>(tile_scheduler_metadata_ptr + i * TileSchedulerMetaDataSize) = *reinterpret_cast<int4 *>(tile_scheduler_metadata0);
tile_scheduler_metadata_ptr[i * TileSchedulerMetaDataSize + 4] = tile_scheduler_metadata1;
}
FLASH_DEVICE_ASSERT(now_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}
__syncwarp();
for (int i = threadIdx.x; i <= batch_size; i += 32) {
num_splits_ptr[i] = num_splits_shared[i];
}
}
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream) {
FLASH_ASSERT(params.batch_size < MaxBatchSize);
get_mla_metadata_kernel<<<1, 32, 0, stream>>>(params);
CHECK_CUDA_KERNEL_LAUNCH();
}

View File

@@ -0,0 +1,65 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/flash_mla.h
#pragma once
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Flash_fwd_mla_params {
using index_t = int64_t;
int b, seqlen_q, d, d_v;
int h, h_h_k_ratio, ngroups;
bool is_causal;
float scale_softmax, scale_softmax_log2;
int *__restrict__ cu_seqlens_k;
void *__restrict__ q_ptr;
void *__restrict__ k_ptr;
void *__restrict__ v_ptr;
void *__restrict__ o_ptr;
void *__restrict__ softmax_lse_ptr;
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t o_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t o_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;
index_t o_head_stride;
int *__restrict__ block_table;
index_t block_table_batch_stride;
int page_block_size;
int *__restrict__ tile_scheduler_metadata_ptr;
int num_sm_parts;
int *__restrict__ num_splits_ptr;
void *__restrict__ softmax_lseaccum_ptr;
void *__restrict__ oaccum_ptr;
};
static constexpr int TileSchedulerMetaDataSize = 8;
// [begin_idx, begin_seqlen, end_idx, end_seqlen, begin_n_split_idx, _, _, _]
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int Headdim>
void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, cudaStream_t stream);
struct Mla_metadata_params {
int *__restrict__ seqlens_k_ptr;
int *__restrict__ tile_scheduler_metadata_ptr;
int *__restrict__ num_splits_ptr;
int batch_size;
int block_size_n;
int fixed_overhead_num_blocks;
int num_sm_parts;
};
void get_mla_metadata_func(Mla_metadata_params &params, cudaStream_t stream);

View File

@@ -0,0 +1,17 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/named_barrier.h
#pragma once
#include "cutlass/barrier.h"
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
// Enumerates the reserved named barriers to avoid potential conflicts
enum class NamedBarriers {
SReady = 1,
SoftmaxReady = 2,
};
} // flash

View File

@@ -0,0 +1,197 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/softmax.h
#pragma once
#include <cmath>
#include <cute/tensor.hpp>
#include <cutlass/numeric_types.h>
#include "utils.h"
namespace flash {
using namespace cute;
////////////////////////////////////////////////////////////////////////////////////////////////////
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); mi++) {
summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
summary(mi) = op(summary(mi), tensor(mi, ni));
}
}
}
template<typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void quad_allreduce_(Tensor<Engine0, Layout0> &dst, Tensor<Engine1, Layout1> &src, Operator &op) {
CUTE_STATIC_ASSERT_V(size(dst) == size(src));
#pragma unroll
for (int i = 0; i < size(dst); i++){
dst(i) = Allreduce<4>::run(src(i), op);
}
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Operator>
__device__ __forceinline__ void reduce_(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &summary, Operator &op) {
thread_reduce_<zero_init>(tensor, summary, op);
quad_allreduce_(summary, summary, op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_max(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &max){
MaxOp<float> max_op;
reduce_<zero_init>(tensor, max, max_op);
}
template<bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__device__ __forceinline__ void reduce_sum(Tensor<Engine0, Layout0> const& tensor, Tensor<Engine1, Layout1> &sum){
SumOp<float> sum_op;
thread_reduce_<zero_init>(tensor, sum, sum_op);
}
// Apply the exp to all the elements.
template <bool Scale_max=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto scale_apply_exp2(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> const &max, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * (Scale_max ? scale : float(M_LOG2E));
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// The following macro will disable the use of fma.
// See: https://github.com/pytorch/pytorch/issues/121558 for more details
// This macro is set in PyTorch and not FlashAttention
#ifdef UNFUSE_FMA
tensor(mi, ni) = exp2f(__fmul_rn(tensor(mi, ni), scale) - max_scaled);
#else
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
#endif
}
}
return tensor;
}
// Apply the exp to all the elements.
template <bool zero_init=true, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tensor<Engine1, Layout1> &max, Tensor<Engine1, Layout1> &sum, const float scale) {
static_assert(Layout0::rank == 2, "Only support 2D Tensor");
static_assert(Layout1::rank == 1, "Only support 1D Tensor");
CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor));
#pragma unroll
for (int mi = 0; mi < size<0>(tensor); ++mi) {
MaxOp<float> max_op;
max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0));
#pragma unroll
for (int ni = 1; ni < size<1>(tensor); ni++) {
max(mi) = max_op(max(mi), tensor(mi, ni));
}
max(mi) = Allreduce<4>::run(max(mi), max_op);
// If max is -inf, then all elements must have been -inf (possibly due to masking).
// We don't want (-inf - (-inf)) since that would give NaN.
const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale;
sum(mi) = 0;
#pragma unroll
for (int ni = 0; ni < size<1>(tensor); ++ni) {
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled);
sum(mi) += tensor(mi, ni);
}
SumOp<float> sum_op;
sum(mi) = Allreduce<4>::run(sum(mi), sum_op);
}
}
template<typename Tensor0, typename Tensor1>
__forceinline__ __device__ void rescale_o(Tensor0 &acc_o, Tensor1 &scale_o) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
#pragma unroll
for (int mi = 0; mi < size(scale_o); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale_o(mi); }
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <int kNRows>
struct Softmax {
using TensorT = decltype(make_tensor<float>(Shape<Int<kNRows>>{}));
TensorT row_max, row_sum;
__forceinline__ __device__ Softmax() {};
template<bool Is_first, bool Check_inf=false, typename Tensor0>
__forceinline__ __device__ TensorT softmax(Tensor0 &acc_s, float softmax_scale_log2) {
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert(decltype(size<0>(scores))::value == kNRows);
TensorT scale_o;
clear(scale_o);
if (Is_first) {
flash::template reduce_max</*zero_init=*/true>(scores, row_max);
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
flash::reduce_sum</*zero_init=*/true>(scores, row_sum);
} else {
Tensor scores_max_prev = make_fragment_like(row_max);
cute::copy(row_max, scores_max_prev);
flash::template reduce_max</*zero_init=*/false>(scores, row_max);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
#pragma unroll
for (int mi = 0; mi < size(row_max); ++mi) {
float scores_max_cur = !Check_inf
? row_max(mi)
: (row_max(mi) == -INFINITY ? 0.0f : row_max(mi));
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
scale_o(mi) = scores_scale;
row_sum(mi) *= scores_scale;
}
flash::scale_apply_exp2(scores, row_max, softmax_scale_log2);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash::reduce_sum</*zero_init=*/false>(scores, row_sum);
}
return scale_o;
};
template<bool Is_dropout=false, bool Split=false, typename Tensor0>
__forceinline__ __device__ TensorT normalize_softmax_lse(Tensor0 &acc_o, float softmax_scale, float rp_dropout=1.0) {
SumOp<float> sum_op;
quad_allreduce_(row_sum, row_sum, sum_op);
TensorT lse = make_fragment_like(row_sum);
// Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for (int mi = 0; mi < size<0>(acc_o_rowcol); ++mi) {
float sum = row_sum(mi);
float inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
lse(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * softmax_scale + __logf(sum);
float scale = !Is_dropout ? inv_sum : inv_sum * rp_dropout;
#pragma unroll
for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scale; }
}
return lse;
};
};
} // namespace flash

View File

@@ -0,0 +1,67 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/static_switch.h
#pragma once
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while(0)
#define CHECK_CUDA_KERNEL_LAUNCH() CHECK_CUDA(cudaGetLastError())
#define FLASH_ASSERT(cond) \
do { \
if (not (cond)) { \
fprintf(stderr, "Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
exit(1); \
} \
} while(0)
#define FLASH_DEVICE_ASSERT(cond) \
do { \
if (not (cond)) { \
printf("Assertion failed (%s:%d): %s\n", __FILE__, __LINE__, #cond); \
asm("trap;"); \
} \
} while(0)
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
[&] { \
if (COND) { \
constexpr static bool CONST_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr static bool CONST_NAME = false; \
return __VA_ARGS__(); \
} \
}()
#define MLA_NUM_SPLITS_SWITCH(NUM_SPLITS, NAME, ...) \
[&] { \
if (NUM_SPLITS <= 32) { \
constexpr static int NAME = 32; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 64) { \
constexpr static int NAME = 64; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 96) { \
constexpr static int NAME = 96; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 128) { \
constexpr static int NAME = 128; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
}()

View File

@@ -0,0 +1,238 @@
// Adapted from https://github.com/deepseek-ai/FlashMLA/blob/main/csrc/utils.h
#pragma once
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <cuda_bf16.h>
#include <cute/tensor.hpp>
#include <cutlass/array.h>
#include <cutlass/cutlass.h>
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
namespace flash {
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct MaxOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; }
};
template <>
struct MaxOp<float> {
// This is slightly faster
__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct SumOp {
__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; }
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<int THREADS>
struct Allreduce {
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
constexpr int OFFSET = THREADS / 2;
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
return Allreduce<OFFSET>::run(x, op);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template<>
struct Allreduce<2> {
template<typename T, typename Operator>
static __device__ __forceinline__ T run(T x, Operator &op) {
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
return x;
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma>
__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) {
constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value;
// Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const
if constexpr (Is_RS) { cute::warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
warpgroup_fence_operand(tCrC);
if constexpr (arrive) {
warpgroup_arrive();
}
if constexpr (zero_init) {
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
} else {
// cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
}
}
if constexpr (commit) {
warpgroup_commit_batch();
}
if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); }
warpgroup_fence_operand(tCrC);
if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); }
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N))
template<bool Transposed=false, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout0 acc_layout) {
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = acc_layout;
if constexpr (!Transposed) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)));
} else {
return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
}
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N)
if constexpr (!Transposed) {
return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l)));
} else {
return make_layout(make_layout(get<0, 0>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l)));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2)
// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8.
// For SM90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N))
// For SM90, FP8, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((4, 2, 2), MMA_M, (N / 32, MMA_N))
template<typename MMA_Traits, typename Layout0>
__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout0 acc_layout) {
using X = Underscore;
if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90
static_assert(decltype(size<0, 0>(acc_layout))::value == 2);
static_assert(decltype(size<0, 1>(acc_layout))::value == 2);
static_assert(decltype(rank(acc_layout))::value == 3);
static_assert(decltype(rank(get<0>(acc_layout)))::value == 3);
if constexpr (sizeof(typename MMA_Traits::ValTypeA) == 2) {
auto l = logical_divide(get<0, 2>(acc_layout), Tile<_2>{}); // ((2, N / 16))
return make_layout(make_layout(get<0, 0>(acc_layout), get<0, 1>(acc_layout), get<0, 0>(l)), get<1>(acc_layout), coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
} else {
static_assert(sizeof(typename MMA_Traits::ValTypeA) == 1);
static_assert(decltype(stride<0, 0>(acc_layout))::value == 1);
static_assert(decltype(stride<0, 1>(acc_layout))::value == 2);
auto l = logical_divide(get<0, 2>(acc_layout), Tile<Layout<Shape<_2, _2>>>{}); // (((2, 2), N / 32))
// This combines the first two modes (<0, 0> and <0, 1>) into one mode.
// Will require register shuffling later to be correct.
return make_layout(make_layout(Layout<_4>{}, get<0, 0, 0>(l), get<0, 0, 1>(l)),
get<1>(acc_layout),
coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout)))); // ((4, 2, 2), MMA_M, N / 32 * MMA_N)
// This combination is right but doesn't work with register shuffling.
// return make_layout(make_layout(coalesce(make_layout(get<0, 0>(acc_layout), get<0, 0, 0>(l))), get<0, 1>(acc_layout), get<0, 0, 1>(l)),
// get<1>(acc_layout),
// coalesce(make_layout(get<0, 1>(l), get<2>(acc_layout))));
}
} else { // SM80
static_assert(decltype(size<0>(acc_layout))::value == 4);
static_assert(decltype(rank(acc_layout))::value == 3);
constexpr int mma_shape_K = get<2>(typename MMA_Traits::Shape_MNK{});
static_assert(mma_shape_K == 8 || mma_shape_K == 16);
if constexpr (mma_shape_K == 8) {
return acc_layout;
} else {
auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2)))
return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l));
}
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename To_type, typename Engine, typename Layout>
__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
// HACK: this requires tensor to be "contiguous"
auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// Blocks until all but N previous cp.async.commit_group operations have committed.
// This differs from cute::cp_async_wait in that when N = 0 we don't call cp.async.wait_all
// (which is equivalent to commit_group then wait_group 0).
// Instead we just call cp.async.wait_group 0, which is slightly faster.
// https://github.com/NVIDIA/cutlass/blob/master/include/cute/arch/copy_sm80.hpp#L113
template <int N>
CUTE_HOST_DEVICE
void cp_async_wait() {
#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED)
asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1,
typename Engine2, typename Layout2, typename Engine3, typename Layout3>
__forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layout0> const &S,
Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN,
Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) {
CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{});
CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{});
CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA
CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M
CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K
// There's no case where !Clear_OOB_K && Clear_OOB_MN
static_assert(!(Clear_OOB_MN && !Clear_OOB_K));
#pragma unroll
for (int m = 0; m < size<1>(S); ++m) {
if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) {
#pragma unroll
for (int k = 0; k < size<2>(S); ++k) {
if (Is_even_K || predicate_K(k)) {
cute::copy(tiled_copy, S(_, m, k), D(_, m, k));
} else if (Clear_OOB_K) {
cute::clear(D(_, m, k));
}
}
} else if (Clear_OOB_MN) {
cute::clear(D(_, m, _));
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace flash

View File

@@ -146,6 +146,7 @@ foreach(EXAMPLE
64_ada_fp8_gemm_grouped
65_distributed_gemm
67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling
68_hopper_flash_mla
69_hopper_mixed_dtype_grouped_gemm
70_blackwell_gemm
71_blackwell_gemm_with_collective_builder