Add unified attention (42_unified_attention)

Squashed from aghamari/unified-attention-decode-opt branch.

CK tile paged-KV attention kernel optimized for decode with 4-tier
dispatch (tiny/small/medium/large), 16x16 MFMA, 2D decode grid,
head-group merging. Supports hdim=64 GQA-8 and hdim=128 MHA with
block_size=32.

Made-with: Cursor
This commit is contained in:
root
2026-04-01 16:24:53 +00:00
parent ec2db01e4a
commit cd7ba6e2e8
7 changed files with 19 additions and 455 deletions

View File

@@ -36,7 +36,7 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {32: 32, 64: 64, 96: 128, 128: 128, 256: 256}
SUPPORTED_PAGE_SIZE = [1, 16, 1024]
SUPPORTED_PAGE_SIZE = [1, 16, 32, 1024]
SUPPORTED_KV_MEMORY_LAYOUT = ["vectorized", "linear"]
SUPPORTED_KV_LOOKUP_TABLE = ["vllm", "sglang"]
KV_MEMORY_LAYOUT_ENUM_MAP = {

View File

@@ -1,24 +1,11 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
add_executable(tile_example_topk_softmax_decode
topk_softmax_decode.cpp
topk_softmax_decode_api.cpp
topk_softmax_api.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/moe_sorting_api.cpp)
target_include_directories(tile_example_topk_softmax_decode PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/)
target_compile_definitions(tile_example_topk_softmax_decode PRIVATE
CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID=1
MOE_SORTING_FMOE_2D_BUF=1)
target_compile_options(tile_example_topk_softmax_decode PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})

View File

@@ -1,314 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <time.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "topk_softmax_decode_api.hpp"
#include "topk_softmax_api.hpp"
#include "moe_sorting_api.hpp"
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
#endif
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
// CPU reference: softmax -> topk -> moe_sorting
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
bool reference_fused(const ck_tile::HostTensor<InputType>& x_host,
ck_tile::index_t topk,
ck_tile::index_t num_experts,
ck_tile::index_t unit_size,
ck_tile::HostTensor<IndexType>& ref_sorted_ids,
ck_tile::HostTensor<WeightType>& ref_sorted_weights,
ck_tile::HostTensor<IndexType>& ref_sorted_expert_ids,
ck_tile::index_t& ref_unit_cnt)
{
auto probs = ck_tile::reference_softmax<InputType, WeightType, WeightType>(x_host);
ck_tile::HostTensor<WeightType> topk_vals({1, topk});
ck_tile::HostTensor<IndexType> topk_idxs({1, topk});
ck_tile::reference_topk(probs, topk_vals, topk_idxs, topk);
ck_tile::HostTensor<IndexType> local_expert_mask({num_experts});
ref_unit_cnt = 0;
ck_tile::reference_moe_sorting<WeightType, IndexType>(
topk_idxs,
topk_vals,
local_expert_mask,
ref_sorted_ids,
ref_sorted_weights,
ref_sorted_expert_ids,
ref_unit_cnt,
num_experts,
unit_size,
1,
false,
true);
return true;
}
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "do CPU validation")
.insert("pr_i", "bf16", "input data type: fp16/bf16")
.insert("e", "128", "number of experts")
.insert("k", "8", "topk")
.insert("unit", "32", "unit_size (block_size_M)")
.insert("model_dim", "7168", "model dimension for moe_buf zeroing")
.insert("seed", "-1", "random seed, -1 = random")
.insert("warmup", "5", "warmup iterations")
.insert("repeat", "20", "benchmark iterations");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
bool run_test(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string pr = args.get_str("pr_i");
int experts = args.get_int("e");
int topk = args.get_int("k");
int unit_size = args.get_int("unit");
int model_dim = args.get_int("model_dim");
int seed = args.get_int("seed");
int warmup = args.get_int("warmup");
int repeat = args.get_int("repeat");
if(seed < 0)
seed = std::time(nullptr);
if(topk > experts)
{
printf("topk %d > experts %d, skip\n", topk, experts);
return false;
}
int tokens = 1;
int max_num_tokens_padded = topk + experts * unit_size - topk;
int max_num_m_blocks = (max_num_tokens_padded + unit_size - 1) / unit_size;
// Host tensors
ck_tile::HostTensor<InputType> x_host({1, experts});
{
auto rng = ck_tile::FillUniformDistribution_Unique<InputType>{
-5.f, 5.f, static_cast<uint32_t>(seed)};
ck_tile::HostTensor<InputType> row({experts});
rng(row);
std::copy(row.begin(), row.end(), x_host.begin());
}
// ---------- Device buffers (shared) ----------
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
// ---------- Fused kernel buffers ----------
ck_tile::DeviceMem fused_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
ck_tile::DeviceMem fused_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
ck_tile::DeviceMem fused_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
ck_tile::DeviceMem fused_num_valid(2 * sizeof(IndexType));
ck_tile::DeviceMem fused_moe_buf(model_dim * sizeof(WeightType));
{
std::vector<float> ones(model_dim, 1.0f);
fused_moe_buf.ToDevice(ones.data());
}
// ---------- Two-kernel baseline buffers ----------
ck_tile::DeviceMem topk_w_dev(topk * sizeof(WeightType));
ck_tile::DeviceMem topk_i_dev(topk * sizeof(IndexType));
ck_tile::DeviceMem base_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
ck_tile::DeviceMem base_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
ck_tile::DeviceMem base_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
ck_tile::DeviceMem base_num_valid(2 * sizeof(IndexType));
ck_tile::DeviceMem base_moe_buf(model_dim * sizeof(WeightType));
int ws_size = moe_sorting_get_workspace_size(tokens, experts, topk, 0);
ck_tile::DeviceMem base_ws(ws_size > 0 ? ws_size : 1);
if(ws_size > 0)
base_ws.SetZero();
ck_tile::stream_config sc{nullptr, true, 0, warmup, repeat};
// ====================== Fused kernel ======================
topk_softmax_decode_trait fused_trait{pr, "fp32", experts, "softmax"};
topk_softmax_decode_kargs fused_karg{
x_dev.GetDeviceBuffer(),
experts,
topk,
experts,
true,
fused_sorted_ids.GetDeviceBuffer(),
fused_sorted_weights.GetDeviceBuffer(),
fused_sorted_expert_ids.GetDeviceBuffer(),
fused_num_valid.GetDeviceBuffer(),
fused_moe_buf.GetDeviceBuffer(),
unit_size,
model_dim,
static_cast<int>(sizeof(WeightType))};
float ms_fused = topk_softmax_decode(fused_trait, fused_karg, sc);
// ============= Two-kernel baseline: topk + moe_sorting =============
topk_softmax_trait ts_trait{pr, "fp32", experts, "softmax"};
topk_softmax_kargs ts_karg{
x_dev.GetDeviceBuffer(),
topk_w_dev.GetDeviceBuffer(),
topk_i_dev.GetDeviceBuffer(),
tokens,
experts,
topk,
experts,
topk};
moe_sorting_trait ms_trait{"int32", "fp32", false, true, 0};
moe_sorting_args ms_arg{
topk_i_dev.GetDeviceBuffer(),
topk_w_dev.GetDeviceBuffer(),
nullptr,
nullptr,
base_sorted_ids.GetDeviceBuffer(),
base_sorted_weights.GetDeviceBuffer(),
base_sorted_expert_ids.GetDeviceBuffer(),
base_num_valid.GetDeviceBuffer(),
base_moe_buf.GetDeviceBuffer(),
ws_size > 0 ? base_ws.GetDeviceBuffer() : nullptr,
tokens,
unit_size,
experts,
topk,
model_dim,
static_cast<int>(sizeof(WeightType))};
// Time the two kernels together using launch_kernel with two lambdas
auto sc_sub = ck_tile::stream_config{nullptr, false, 0, 0, 1};
float ms_baseline = ck_tile::launch_kernel(
sc,
[&](const ck_tile::stream_config&) { topk_softmax(ts_trait, ts_karg, sc_sub); },
[&](const ck_tile::stream_config&) { moe_sorting(ms_trait, ms_arg, sc_sub); });
float speedup = (ms_baseline > 0 && ms_fused > 0) ? ms_baseline / ms_fused : 0;
printf("[%s] E:%d, k:%d, unit:%d | fused:%.4fms baseline(topk+sort):%.4fms speedup:%.2fx",
pr.c_str(),
experts,
topk,
unit_size,
ms_fused,
ms_baseline,
speedup);
if(ms_fused < 0 || ms_baseline < 0)
{
printf(" (not supported)\n");
return false;
}
// ====================== Validation ======================
bool pass = true;
if(validate)
{
ck_tile::HostTensor<IndexType> sorted_ids_host({max_num_tokens_padded});
ck_tile::HostTensor<WeightType> sorted_weights_host({max_num_tokens_padded});
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_num_m_blocks});
ck_tile::HostTensor<IndexType> num_valid_host({2});
std::vector<float> moe_buf_host(model_dim);
fused_sorted_ids.FromDevice(sorted_ids_host.data());
fused_sorted_weights.FromDevice(sorted_weights_host.data());
fused_sorted_expert_ids.FromDevice(sorted_expert_ids_host.data());
fused_num_valid.FromDevice(num_valid_host.data());
fused_moe_buf.FromDevice(moe_buf_host.data());
ck_tile::HostTensor<IndexType> ref_sorted_ids({max_num_tokens_padded});
ck_tile::HostTensor<WeightType> ref_sorted_weights({max_num_tokens_padded});
ck_tile::HostTensor<IndexType> ref_sorted_expert_ids({max_num_m_blocks});
IndexType sentinel = static_cast<uint32_t>((1 & 0x00ffffff) | ((topk & 0xff) << 24));
std::fill(ref_sorted_ids.begin(), ref_sorted_ids.end(), sentinel);
std::fill(ref_sorted_weights.begin(), ref_sorted_weights.end(), WeightType(0));
std::fill(ref_sorted_expert_ids.begin(), ref_sorted_expert_ids.end(), -1);
ck_tile::index_t ref_unit_cnt = 0;
reference_fused<InputType, WeightType, IndexType>(
x_host, topk, experts, unit_size,
ref_sorted_ids, ref_sorted_weights, ref_sorted_expert_ids, ref_unit_cnt);
int num_valid_padded = num_valid_host(0);
int num_valid_tokens = num_valid_host(1);
if(num_valid_padded != ref_unit_cnt)
{
printf(" FAIL:num_valid[0] got %d ref %d;", num_valid_padded, ref_unit_cnt);
pass = false;
}
if(num_valid_tokens != 1)
{
printf(" FAIL:num_valid[1] got %d;", num_valid_tokens);
pass = false;
}
int n_tiles = num_valid_padded / unit_size;
for(int i = 1; i < n_tiles; i++)
{
if(sorted_expert_ids_host(i) < sorted_expert_ids_host(i - 1))
{
printf(" FAIL:expert_ids not ascending;");
pass = false;
break;
}
}
WeightType wsum = 0;
for(int i = 0; i < topk; i++)
wsum += sorted_weights_host(i * unit_size);
if(std::abs(wsum - 1.0f) > 1e-3f)
{
printf(" FAIL:wsum=%.6f;", static_cast<float>(wsum));
pass = false;
}
bool buf_zeroed = std::all_of(
moe_buf_host.begin(), moe_buf_host.end(), [](float v) { return v == 0.0f; });
if(!buf_zeroed)
{
printf(" FAIL:moe_buf not zeroed;");
pass = false;
}
}
printf(" valid:%s\n", pass ? "y" : "n");
fflush(stdout);
return pass;
}
int main(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return -1;
std::string pr = args.get_str("pr_i");
bool r = true;
if(pr == "fp16")
r &= run_test<ck_tile::fp16_t, float>(args);
else if(pr == "bf16")
r &= run_test<ck_tile::bf16_t, float>(args);
else
{
printf("unsupported pr_i: %s\n", pr.c_str());
return -1;
}
return r ? 0 : -1;
}

View File

@@ -1,99 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "topk_softmax_decode_api.hpp"
#define TOPK_SOFTMAX_DECODE_DISPATCH(experts_, use_softmax_) \
constexpr ck_tile::index_t ts_experts = experts_; \
constexpr bool ts_use_softmax = use_softmax_; \
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
ts_weight_type, \
ts_index_type, \
ts_experts, \
ts_use_softmax>; \
using ts_pipeline = ck_tile::TopkSoftmaxDecodePipeline<ts_problem>; \
\
using kernel = ck_tile::TopkSoftmaxDecodeKernel<ts_pipeline>; \
\
auto kargs = kernel::MakeKargs(a); \
\
const dim3 grids = kernel::GridSize(a); \
const dim3 blocks = kernel::BlockSize(); \
\
float ave_time = \
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
\
return ave_time;
#define TOPK_SOFTMAX_DECODE_EXPERT_LADDER(use_softmax_) \
if(t.experts <= 8) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(8, use_softmax_) \
} \
else if(t.experts <= 16) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(16, use_softmax_) \
} \
else if(t.experts <= 32) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(32, use_softmax_) \
} \
else if(t.experts <= 64) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(64, use_softmax_) \
} \
else if(t.experts <= 128) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(128, use_softmax_) \
} \
else if(t.experts <= 192) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(192, use_softmax_) \
} \
else if(t.experts <= 256) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(256, use_softmax_) \
} \
else if(t.experts <= 512) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(512, use_softmax_) \
} \
else if(t.experts <= 1024) \
{ \
TOPK_SOFTMAX_DECODE_DISPATCH(1024, use_softmax_)\
}
float topk_softmax_decode(topk_softmax_decode_trait t,
topk_softmax_decode_kargs a,
ck_tile::stream_config s)
{
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
}
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
{
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
}
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
using ts_input_type = ck_tile::fp16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
}
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
{
using ts_input_type = ck_tile::bf16_t;
using ts_weight_type = float;
using ts_index_type = ck_tile::index_t;
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
}
return -1;
}

View File

@@ -1,24 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/topk_softmax.hpp"
#include <string>
struct topk_softmax_decode_trait
{
std::string input_type;
std::string weight_type; // currently always float
int experts;
std::string activation; // "softmax" or "sigmoid"
};
struct topk_softmax_decode_kargs : public ck_tile::TopkSoftmaxDecodeHostArgs
{
};
float topk_softmax_decode(topk_softmax_decode_trait t,
topk_softmax_decode_kargs a,
ck_tile::stream_config s);