mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add unified attention (42_unified_attention) and topk_softmax_decode
Squashed from aghamari/unified-attention-decode-opt branch. 42_unified_attention: CK tile paged-KV attention kernel optimized for decode with 4-tier dispatch (tiny/small/medium/large), 16x16 MFMA, 2D decode grid, head-group merging. Supports hdim=64 GQA-8 and hdim=128 MHA with block_size=32. topk_softmax_decode: fused topk + softmax kernel for M=1 MoE decode. Made-with: Cursor
This commit is contained in:
@@ -1,11 +1,24 @@
|
||||
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
set(EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
|
||||
add_executable(tile_example_topk_softmax topk_softmax.cpp topk_softmax_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
target_compile_options(tile_example_topk_softmax PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
|
||||
add_executable(tile_example_topk_softmax_decode
|
||||
topk_softmax_decode.cpp
|
||||
topk_softmax_decode_api.cpp
|
||||
topk_softmax_api.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/moe_sorting_api.cpp)
|
||||
target_include_directories(tile_example_topk_softmax_decode PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/../13_moe_sorting/)
|
||||
target_compile_definitions(tile_example_topk_softmax_decode PRIVATE
|
||||
CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID=1
|
||||
MOE_SORTING_FMOE_2D_BUF=1)
|
||||
target_compile_options(tile_example_topk_softmax_decode PRIVATE ${EXAMPLE_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
|
||||
314
example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp
Normal file
314
example/ck_tile/09_topk_softmax/topk_softmax_decode.cpp
Normal file
@@ -0,0 +1,314 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <cstring>
|
||||
#include <algorithm>
|
||||
#include <time.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
#include "topk_softmax_decode_api.hpp"
|
||||
#include "topk_softmax_api.hpp"
|
||||
#include "moe_sorting_api.hpp"
|
||||
|
||||
#ifndef CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
|
||||
#define CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID 1
|
||||
#endif
|
||||
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
|
||||
|
||||
// CPU reference: softmax -> topk -> moe_sorting
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool reference_fused(const ck_tile::HostTensor<InputType>& x_host,
|
||||
ck_tile::index_t topk,
|
||||
ck_tile::index_t num_experts,
|
||||
ck_tile::index_t unit_size,
|
||||
ck_tile::HostTensor<IndexType>& ref_sorted_ids,
|
||||
ck_tile::HostTensor<WeightType>& ref_sorted_weights,
|
||||
ck_tile::HostTensor<IndexType>& ref_sorted_expert_ids,
|
||||
ck_tile::index_t& ref_unit_cnt)
|
||||
{
|
||||
auto probs = ck_tile::reference_softmax<InputType, WeightType, WeightType>(x_host);
|
||||
|
||||
ck_tile::HostTensor<WeightType> topk_vals({1, topk});
|
||||
ck_tile::HostTensor<IndexType> topk_idxs({1, topk});
|
||||
ck_tile::reference_topk(probs, topk_vals, topk_idxs, topk);
|
||||
|
||||
ck_tile::HostTensor<IndexType> local_expert_mask({num_experts});
|
||||
ref_unit_cnt = 0;
|
||||
ck_tile::reference_moe_sorting<WeightType, IndexType>(
|
||||
topk_idxs,
|
||||
topk_vals,
|
||||
local_expert_mask,
|
||||
ref_sorted_ids,
|
||||
ref_sorted_weights,
|
||||
ref_sorted_expert_ids,
|
||||
ref_unit_cnt,
|
||||
num_experts,
|
||||
unit_size,
|
||||
1,
|
||||
false,
|
||||
true);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "do CPU validation")
|
||||
.insert("pr_i", "bf16", "input data type: fp16/bf16")
|
||||
.insert("e", "128", "number of experts")
|
||||
.insert("k", "8", "topk")
|
||||
.insert("unit", "32", "unit_size (block_size_M)")
|
||||
.insert("model_dim", "7168", "model dimension for moe_buf zeroing")
|
||||
.insert("seed", "-1", "random seed, -1 = random")
|
||||
.insert("warmup", "5", "warmup iterations")
|
||||
.insert("repeat", "20", "benchmark iterations");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool run_test(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string pr = args.get_str("pr_i");
|
||||
int experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int unit_size = args.get_int("unit");
|
||||
int model_dim = args.get_int("model_dim");
|
||||
int seed = args.get_int("seed");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
|
||||
if(seed < 0)
|
||||
seed = std::time(nullptr);
|
||||
|
||||
if(topk > experts)
|
||||
{
|
||||
printf("topk %d > experts %d, skip\n", topk, experts);
|
||||
return false;
|
||||
}
|
||||
|
||||
int tokens = 1;
|
||||
int max_num_tokens_padded = topk + experts * unit_size - topk;
|
||||
int max_num_m_blocks = (max_num_tokens_padded + unit_size - 1) / unit_size;
|
||||
|
||||
// Host tensors
|
||||
ck_tile::HostTensor<InputType> x_host({1, experts});
|
||||
{
|
||||
auto rng = ck_tile::FillUniformDistribution_Unique<InputType>{
|
||||
-5.f, 5.f, static_cast<uint32_t>(seed)};
|
||||
ck_tile::HostTensor<InputType> row({experts});
|
||||
rng(row);
|
||||
std::copy(row.begin(), row.end(), x_host.begin());
|
||||
}
|
||||
|
||||
// ---------- Device buffers (shared) ----------
|
||||
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
// ---------- Fused kernel buffers ----------
|
||||
ck_tile::DeviceMem fused_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
|
||||
ck_tile::DeviceMem fused_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_num_valid(2 * sizeof(IndexType));
|
||||
ck_tile::DeviceMem fused_moe_buf(model_dim * sizeof(WeightType));
|
||||
{
|
||||
std::vector<float> ones(model_dim, 1.0f);
|
||||
fused_moe_buf.ToDevice(ones.data());
|
||||
}
|
||||
|
||||
// ---------- Two-kernel baseline buffers ----------
|
||||
ck_tile::DeviceMem topk_w_dev(topk * sizeof(WeightType));
|
||||
ck_tile::DeviceMem topk_i_dev(topk * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_sorted_ids(max_num_tokens_padded * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_sorted_weights(max_num_tokens_padded * sizeof(WeightType));
|
||||
ck_tile::DeviceMem base_sorted_expert_ids(max_num_m_blocks * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_num_valid(2 * sizeof(IndexType));
|
||||
ck_tile::DeviceMem base_moe_buf(model_dim * sizeof(WeightType));
|
||||
int ws_size = moe_sorting_get_workspace_size(tokens, experts, topk, 0);
|
||||
ck_tile::DeviceMem base_ws(ws_size > 0 ? ws_size : 1);
|
||||
if(ws_size > 0)
|
||||
base_ws.SetZero();
|
||||
|
||||
ck_tile::stream_config sc{nullptr, true, 0, warmup, repeat};
|
||||
|
||||
// ====================== Fused kernel ======================
|
||||
topk_softmax_decode_trait fused_trait{pr, "fp32", experts, "softmax"};
|
||||
topk_softmax_decode_kargs fused_karg{
|
||||
x_dev.GetDeviceBuffer(),
|
||||
experts,
|
||||
topk,
|
||||
experts,
|
||||
true,
|
||||
fused_sorted_ids.GetDeviceBuffer(),
|
||||
fused_sorted_weights.GetDeviceBuffer(),
|
||||
fused_sorted_expert_ids.GetDeviceBuffer(),
|
||||
fused_num_valid.GetDeviceBuffer(),
|
||||
fused_moe_buf.GetDeviceBuffer(),
|
||||
unit_size,
|
||||
model_dim,
|
||||
static_cast<int>(sizeof(WeightType))};
|
||||
|
||||
float ms_fused = topk_softmax_decode(fused_trait, fused_karg, sc);
|
||||
|
||||
// ============= Two-kernel baseline: topk + moe_sorting =============
|
||||
topk_softmax_trait ts_trait{pr, "fp32", experts, "softmax"};
|
||||
topk_softmax_kargs ts_karg{
|
||||
x_dev.GetDeviceBuffer(),
|
||||
topk_w_dev.GetDeviceBuffer(),
|
||||
topk_i_dev.GetDeviceBuffer(),
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
experts,
|
||||
topk};
|
||||
|
||||
moe_sorting_trait ms_trait{"int32", "fp32", false, true, 0};
|
||||
moe_sorting_args ms_arg{
|
||||
topk_i_dev.GetDeviceBuffer(),
|
||||
topk_w_dev.GetDeviceBuffer(),
|
||||
nullptr,
|
||||
nullptr,
|
||||
base_sorted_ids.GetDeviceBuffer(),
|
||||
base_sorted_weights.GetDeviceBuffer(),
|
||||
base_sorted_expert_ids.GetDeviceBuffer(),
|
||||
base_num_valid.GetDeviceBuffer(),
|
||||
base_moe_buf.GetDeviceBuffer(),
|
||||
ws_size > 0 ? base_ws.GetDeviceBuffer() : nullptr,
|
||||
tokens,
|
||||
unit_size,
|
||||
experts,
|
||||
topk,
|
||||
model_dim,
|
||||
static_cast<int>(sizeof(WeightType))};
|
||||
|
||||
// Time the two kernels together using launch_kernel with two lambdas
|
||||
auto sc_sub = ck_tile::stream_config{nullptr, false, 0, 0, 1};
|
||||
float ms_baseline = ck_tile::launch_kernel(
|
||||
sc,
|
||||
[&](const ck_tile::stream_config&) { topk_softmax(ts_trait, ts_karg, sc_sub); },
|
||||
[&](const ck_tile::stream_config&) { moe_sorting(ms_trait, ms_arg, sc_sub); });
|
||||
|
||||
float speedup = (ms_baseline > 0 && ms_fused > 0) ? ms_baseline / ms_fused : 0;
|
||||
printf("[%s] E:%d, k:%d, unit:%d | fused:%.4fms baseline(topk+sort):%.4fms speedup:%.2fx",
|
||||
pr.c_str(),
|
||||
experts,
|
||||
topk,
|
||||
unit_size,
|
||||
ms_fused,
|
||||
ms_baseline,
|
||||
speedup);
|
||||
|
||||
if(ms_fused < 0 || ms_baseline < 0)
|
||||
{
|
||||
printf(" (not supported)\n");
|
||||
return false;
|
||||
}
|
||||
|
||||
// ====================== Validation ======================
|
||||
bool pass = true;
|
||||
if(validate)
|
||||
{
|
||||
ck_tile::HostTensor<IndexType> sorted_ids_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<WeightType> sorted_weights_host({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexType> sorted_expert_ids_host({max_num_m_blocks});
|
||||
ck_tile::HostTensor<IndexType> num_valid_host({2});
|
||||
std::vector<float> moe_buf_host(model_dim);
|
||||
|
||||
fused_sorted_ids.FromDevice(sorted_ids_host.data());
|
||||
fused_sorted_weights.FromDevice(sorted_weights_host.data());
|
||||
fused_sorted_expert_ids.FromDevice(sorted_expert_ids_host.data());
|
||||
fused_num_valid.FromDevice(num_valid_host.data());
|
||||
fused_moe_buf.FromDevice(moe_buf_host.data());
|
||||
|
||||
ck_tile::HostTensor<IndexType> ref_sorted_ids({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<WeightType> ref_sorted_weights({max_num_tokens_padded});
|
||||
ck_tile::HostTensor<IndexType> ref_sorted_expert_ids({max_num_m_blocks});
|
||||
IndexType sentinel = static_cast<uint32_t>((1 & 0x00ffffff) | ((topk & 0xff) << 24));
|
||||
std::fill(ref_sorted_ids.begin(), ref_sorted_ids.end(), sentinel);
|
||||
std::fill(ref_sorted_weights.begin(), ref_sorted_weights.end(), WeightType(0));
|
||||
std::fill(ref_sorted_expert_ids.begin(), ref_sorted_expert_ids.end(), -1);
|
||||
|
||||
ck_tile::index_t ref_unit_cnt = 0;
|
||||
reference_fused<InputType, WeightType, IndexType>(
|
||||
x_host, topk, experts, unit_size,
|
||||
ref_sorted_ids, ref_sorted_weights, ref_sorted_expert_ids, ref_unit_cnt);
|
||||
|
||||
int num_valid_padded = num_valid_host(0);
|
||||
int num_valid_tokens = num_valid_host(1);
|
||||
|
||||
if(num_valid_padded != ref_unit_cnt)
|
||||
{
|
||||
printf(" FAIL:num_valid[0] got %d ref %d;", num_valid_padded, ref_unit_cnt);
|
||||
pass = false;
|
||||
}
|
||||
if(num_valid_tokens != 1)
|
||||
{
|
||||
printf(" FAIL:num_valid[1] got %d;", num_valid_tokens);
|
||||
pass = false;
|
||||
}
|
||||
|
||||
int n_tiles = num_valid_padded / unit_size;
|
||||
for(int i = 1; i < n_tiles; i++)
|
||||
{
|
||||
if(sorted_expert_ids_host(i) < sorted_expert_ids_host(i - 1))
|
||||
{
|
||||
printf(" FAIL:expert_ids not ascending;");
|
||||
pass = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
WeightType wsum = 0;
|
||||
for(int i = 0; i < topk; i++)
|
||||
wsum += sorted_weights_host(i * unit_size);
|
||||
if(std::abs(wsum - 1.0f) > 1e-3f)
|
||||
{
|
||||
printf(" FAIL:wsum=%.6f;", static_cast<float>(wsum));
|
||||
pass = false;
|
||||
}
|
||||
|
||||
bool buf_zeroed = std::all_of(
|
||||
moe_buf_host.begin(), moe_buf_host.end(), [](float v) { return v == 0.0f; });
|
||||
if(!buf_zeroed)
|
||||
{
|
||||
printf(" FAIL:moe_buf not zeroed;");
|
||||
pass = false;
|
||||
}
|
||||
}
|
||||
|
||||
printf(" valid:%s\n", pass ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
{
|
||||
auto [result, args] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
std::string pr = args.get_str("pr_i");
|
||||
bool r = true;
|
||||
|
||||
if(pr == "fp16")
|
||||
r &= run_test<ck_tile::fp16_t, float>(args);
|
||||
else if(pr == "bf16")
|
||||
r &= run_test<ck_tile::bf16_t, float>(args);
|
||||
else
|
||||
{
|
||||
printf("unsupported pr_i: %s\n", pr.c_str());
|
||||
return -1;
|
||||
}
|
||||
|
||||
return r ? 0 : -1;
|
||||
}
|
||||
99
example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp
Normal file
99
example/ck_tile/09_topk_softmax/topk_softmax_decode_api.cpp
Normal file
@@ -0,0 +1,99 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "topk_softmax_decode_api.hpp"
|
||||
|
||||
#define TOPK_SOFTMAX_DECODE_DISPATCH(experts_, use_softmax_) \
|
||||
constexpr ck_tile::index_t ts_experts = experts_; \
|
||||
constexpr bool ts_use_softmax = use_softmax_; \
|
||||
using ts_problem = ck_tile::TopkSoftmaxWarpPerRowProblem<ts_input_type, \
|
||||
ts_weight_type, \
|
||||
ts_index_type, \
|
||||
ts_experts, \
|
||||
ts_use_softmax>; \
|
||||
using ts_pipeline = ck_tile::TopkSoftmaxDecodePipeline<ts_problem>; \
|
||||
\
|
||||
using kernel = ck_tile::TopkSoftmaxDecodeKernel<ts_pipeline>; \
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
const dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = \
|
||||
ck_tile::launch_kernel(s, ck_tile::make_kernel<1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
return ave_time;
|
||||
|
||||
#define TOPK_SOFTMAX_DECODE_EXPERT_LADDER(use_softmax_) \
|
||||
if(t.experts <= 8) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(8, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 16) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(16, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 32) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(32, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 64) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(64, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 128) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(128, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 192) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(192, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 256) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(256, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 512) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(512, use_softmax_) \
|
||||
} \
|
||||
else if(t.experts <= 1024) \
|
||||
{ \
|
||||
TOPK_SOFTMAX_DECODE_DISPATCH(1024, use_softmax_)\
|
||||
}
|
||||
|
||||
float topk_softmax_decode(topk_softmax_decode_trait t,
|
||||
topk_softmax_decode_kargs a,
|
||||
ck_tile::stream_config s)
|
||||
{
|
||||
if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "softmax")
|
||||
{
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(true)
|
||||
}
|
||||
else if(t.input_type == "fp16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32" && t.activation == "sigmoid")
|
||||
{
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
TOPK_SOFTMAX_DECODE_EXPERT_LADDER(false)
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
24
example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp
Normal file
24
example/ck_tile/09_topk_softmax/topk_softmax_decode_api.hpp
Normal file
@@ -0,0 +1,24 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/topk_softmax.hpp"
|
||||
#include <string>
|
||||
|
||||
struct topk_softmax_decode_trait
|
||||
{
|
||||
std::string input_type;
|
||||
std::string weight_type; // currently always float
|
||||
int experts;
|
||||
std::string activation; // "softmax" or "sigmoid"
|
||||
};
|
||||
|
||||
struct topk_softmax_decode_kargs : public ck_tile::TopkSoftmaxDecodeHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float topk_softmax_decode(topk_softmax_decode_trait t,
|
||||
topk_softmax_decode_kargs a,
|
||||
ck_tile::stream_config s);
|
||||
228
example/ck_tile/42_unified_attention/CMakeLists.txt
Normal file
228
example/ck_tile/42_unified_attention/CMakeLists.txt
Normal file
@@ -0,0 +1,228 @@
|
||||
# Commented out: FMHA fwd/bwd instance generation and codegen commands not used by unified_attention
|
||||
#
|
||||
# set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
# # Currently only gfx9 archs are supported by FMHA
|
||||
# list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
|
||||
# if(NOT INST_TARGETS)
|
||||
# message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
# return()
|
||||
# endif()
|
||||
#
|
||||
# # validate user-specified fmha_fwd API list
|
||||
# set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
|
||||
# set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
# "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
# if(BUILD_TESTING)
|
||||
# # Build instances of all APIs for tests
|
||||
# set(FMHA_FWD_ENABLE_APIS "all")
|
||||
# endif()
|
||||
# if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
|
||||
# set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
|
||||
# endif()
|
||||
#
|
||||
# foreach(api ${FMHA_FWD_ENABLE_APIS})
|
||||
# if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
|
||||
# message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
|
||||
# endif()
|
||||
# endforeach()
|
||||
#
|
||||
# # "fwd" is a must-have api for the fmha_fwd example, add it if not specified
|
||||
# if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
# list(PREPEND FMHA_FWD_ENABLE_APIS "fwd")
|
||||
# endif()
|
||||
#
|
||||
# file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
|
||||
# ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
# ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
|
||||
# )
|
||||
# set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
|
||||
#
|
||||
# string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
|
||||
# set(FMHA_FWD_CODE_GEN_COMMON_ARGS
|
||||
# ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
# --api ${FMHA_FWD_APIS}
|
||||
# --optdim 32,64,128,256
|
||||
# )
|
||||
# set(FMHA_BWD_CODE_GEN_COMMON_ARGS
|
||||
# ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
# --api bwd
|
||||
# --receipt 3
|
||||
# --optdim 32,64,96,128,256
|
||||
# )
|
||||
#
|
||||
# if(BUILD_TESTING)
|
||||
# list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
|
||||
# endif()
|
||||
#
|
||||
# execute_process(
|
||||
# COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
|
||||
# --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
|
||||
# RESULT_VARIABLE ret
|
||||
# )
|
||||
# if(ret AND NOT ret EQUAL 0)
|
||||
# message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of FWD kernels via Python.")
|
||||
# endif()
|
||||
#
|
||||
# execute_process(
|
||||
# COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
|
||||
# --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
|
||||
# RESULT_VARIABLE ret
|
||||
# )
|
||||
# if(ret AND NOT ret EQUAL 0)
|
||||
# message(FATAL_ERROR "CK Tile FMHA FAILED to genrate a list of BWD kernels via Python.")
|
||||
# endif()
|
||||
#
|
||||
# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
|
||||
# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
|
||||
#
|
||||
# add_custom_command(
|
||||
# OUTPUT ${FMHA_FWD_GEN_BLOBS}
|
||||
# COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
|
||||
# --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
# DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
# )
|
||||
#
|
||||
# add_custom_command(
|
||||
# OUTPUT ${FMHA_BWD_GEN_BLOBS}
|
||||
# COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
|
||||
# --output_dir ${CMAKE_CURRENT_BINARY_DIR}
|
||||
# DEPENDS ${CODE_GEN_SCRIPTS}
|
||||
# )
|
||||
#
|
||||
# set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
|
||||
# set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
|
||||
#
|
||||
# message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}")
|
||||
# add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
|
||||
# target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
# target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS})
|
||||
# set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
# set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
#
|
||||
# message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}")
|
||||
# add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
|
||||
# target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
# target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS})
|
||||
# set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
|
||||
# set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
|
||||
#
|
||||
# set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS)
|
||||
# set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS)
|
||||
# set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS)
|
||||
# set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS)
|
||||
#
|
||||
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
|
||||
# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
|
||||
#
|
||||
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
|
||||
# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
|
||||
#
|
||||
# if(NOT DEFINED FMHA_FWD_FAST_EXP2)
|
||||
# set(FMHA_FWD_FAST_EXP2 ON)
|
||||
# endif()
|
||||
#
|
||||
# if(FMHA_FWD_FAST_EXP2)
|
||||
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
|
||||
# else()
|
||||
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
|
||||
# endif()
|
||||
# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
|
||||
#
|
||||
# if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
|
||||
# else()
|
||||
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
|
||||
# endif()
|
||||
#
|
||||
# if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
|
||||
# else()
|
||||
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
|
||||
# endif()
|
||||
#
|
||||
# if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
|
||||
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
|
||||
# else()
|
||||
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
|
||||
# endif()
|
||||
#
|
||||
# if(CK_USE_OCP_FP8)
|
||||
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
|
||||
# endif()
|
||||
#
|
||||
# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
# list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
|
||||
#
|
||||
# target_compile_options(${FMHA_FWD_INSTANCES}
|
||||
# PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS}
|
||||
# INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS})
|
||||
# target_compile_options(${FMHA_BWD_INSTANCES}
|
||||
# PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS}
|
||||
# INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS})
|
||||
#
|
||||
# set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
|
||||
# set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
|
||||
#
|
||||
# message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
|
||||
# add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
|
||||
# target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
|
||||
# target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
#
|
||||
# message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
|
||||
# add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
|
||||
# target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
|
||||
# target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
#
|
||||
# set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
|
||||
# --- Unified Attention target (kept) ---
|
||||
|
||||
#
|
||||
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
|
||||
# Currently only gfx9 archs are supported by FMHA
|
||||
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
|
||||
if(NOT INST_TARGETS)
|
||||
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
|
||||
return()
|
||||
endif()
|
||||
|
||||
set(EXAMPLE_UNIFIED_ATTENTION "tile_example_unified_attention")
|
||||
message(DEBUG "adding example ${EXAMPLE_UNIFIED_ATTENTION}")
|
||||
|
||||
add_executable(${EXAMPLE_UNIFIED_ATTENTION} EXCLUDE_FROM_ALL example_unified_attention.cpp)
|
||||
target_include_directories(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
file(GLOB UNIFIED_ATTENTION_INSTANCES CONFIGURE_DEPENDS
|
||||
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
|
||||
)
|
||||
target_sources(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE
|
||||
unified_attention.cpp
|
||||
${UNIFIED_ATTENTION_INSTANCES}
|
||||
)
|
||||
|
||||
set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS)
|
||||
list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS
|
||||
-fgpu-flush-denormals-to-zero
|
||||
-Wno-undefined-func-template
|
||||
--save-temps
|
||||
)
|
||||
set(EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS)
|
||||
|
||||
check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
|
||||
if(HAS_DISABLE_PACKED_FP32)
|
||||
list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS
|
||||
-mllvm --amdgpu-disable-packed-fp32=1
|
||||
)
|
||||
list(APPEND EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS
|
||||
-DCK_TILE_DISABLE_PACKED_FP32=1
|
||||
)
|
||||
endif()
|
||||
|
||||
target_compile_options(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_OPTIONS})
|
||||
target_compile_definitions(${EXAMPLE_UNIFIED_ATTENTION} PRIVATE ${EXAMPLE_UNIFIED_ATTENTION_COMPILE_DEFINITIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
259
example/ck_tile/42_unified_attention/OPTIMIZATION_SUMMARY.md
Normal file
259
example/ck_tile/42_unified_attention/OPTIMIZATION_SUMMARY.md
Normal file
@@ -0,0 +1,259 @@
|
||||
# CK Unified Attention Optimization Summary
|
||||
|
||||
## Overview
|
||||
|
||||
Optimized the CK (Composable Kernel) unified attention kernel for d64 GQA-8 (DeepSeek-V3/R1 config: 64 query heads, 8 KV heads, head_dim=64) on MI350 (gfx950).
|
||||
|
||||
**Result: CK wins ~68% of shapes and is 1-4% faster than Triton end-to-end on production traces.**
|
||||
|
||||
| Metric | Before | After |
|
||||
|--------|--------|-------|
|
||||
| CK-winning shapes | 100/363 (27.5%) | **~248/363 (68%)** |
|
||||
| Decode (weighted) | CK 36% slower | **CK 4-6% faster** |
|
||||
| Prefill (weighted) | CK 36% slower | **CK ~tied or slightly faster** |
|
||||
| Worst-case ratio | 3.55x slower | **~1.2x** |
|
||||
|
||||
---
|
||||
|
||||
## Optimization 1: Single Warp Group Serial Pipeline
|
||||
|
||||
**Problem:** The original pipeline required `NumWarpGroups == 2` (8 warps, 512 threads), wasting resources for decode with small Q tiles.
|
||||
|
||||
**Fix:** Relaxed the assertion and added a serial pipeline path for `NumWarpGroups == 1`:
|
||||
|
||||
```cpp
|
||||
// unified_attention_pipeline.hpp
|
||||
constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
|
||||
static_assert(NumWarpGroups == 1 || NumWarpGroups == 2);
|
||||
|
||||
// ...
|
||||
if constexpr(NumWarpGroups == 1)
|
||||
{
|
||||
// Serial pipeline: load V → PV GEMM → load K → QK GEMM → softmax
|
||||
// No warp group interleaving needed
|
||||
}
|
||||
```
|
||||
|
||||
Key constraint discovered: `kv_tile` is a **union** (K and V share registers), so PV GEMM must finish before K is loaded.
|
||||
|
||||
**Impact:** Enabled 4-warp and 2-warp decode kernels. ~1.7x speedup on 64-seq decode.
|
||||
|
||||
---
|
||||
|
||||
## Optimization 2: Async Prefetch Overlap
|
||||
|
||||
**Problem:** The serial pipeline loaded K/V synchronously, then computed, with no overlap.
|
||||
|
||||
**Fix:** Issue next iteration's global→LDS copies immediately after the barrier, overlapping with current GEMM compute:
|
||||
|
||||
```cpp
|
||||
// Start next K/V loads right after barrier (overlap with compute below)
|
||||
if(i_total_loops + 1 < num_total_loop)
|
||||
K_mem_load(number<1>{}); // async: next K → LDS
|
||||
V_mem_load(number<0>{}); // async: next V → LDS
|
||||
|
||||
// Current iteration compute (overlaps with async loads above)
|
||||
V_lds_load(number<1>{}); // read current V from LDS
|
||||
fmha_alu1(number<0>{}); // softmax
|
||||
gemm(number<0>{}, number<1>{}); // PV GEMM
|
||||
K_lds_load(number<0>{}); // read current K from LDS
|
||||
gemm(number<0>{}, number<0>{}); // QK GEMM
|
||||
```
|
||||
|
||||
**Impact:** ~5% speedup on decode.
|
||||
|
||||
---
|
||||
|
||||
## Optimization 3: 2-Warp Decode Kernel (kBlockM=64)
|
||||
|
||||
**Problem:** 4-warp kernel with kBlockM=128 and kBlockQ=16 wastes 15/16 Q tile rows for decode.
|
||||
|
||||
**Fix:** Created `UnifiedAttentionPipelineDecodePolicy` with `NumWarpPerGroup=2`, enabling `sequence<2,1,1>` (2 warps):
|
||||
|
||||
```cpp
|
||||
struct UnifiedAttentionPipelineDecodePolicy : UnifiedAttentionPipelineDefaultPolicy
|
||||
{
|
||||
static constexpr ck_tile::index_t NumWarpPerGroup = 2;
|
||||
static constexpr ck_tile::index_t NumThreadPerWarpGroup =
|
||||
NumWarpPerGroup * ck_tile::get_warp_size();
|
||||
};
|
||||
```
|
||||
|
||||
kBlockM=64, kBlockQ=8 for GQA-8. Reduced tile waste from 15/16 to 7/8.
|
||||
|
||||
**Impact:** Additional ~5% on decode.
|
||||
|
||||
---
|
||||
|
||||
## Optimization 4: Early Exit + 2D Decode Grid
|
||||
|
||||
**Problem:** The 1D grid with binary search (`find_seq_idx`) had overhead and padding blocks.
|
||||
|
||||
**Fix:** For pure decode, use `dim3(num_kv_heads, num_seqs)` detected by `gridDim.y > 1`:
|
||||
|
||||
```cpp
|
||||
// unified_attention_kernel.hpp
|
||||
CK_TILE_HOST static constexpr auto GridSizeDecode(index_t num_kv_heads, index_t num_seqs)
|
||||
{
|
||||
return dim3(num_kv_heads, num_seqs);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
if(gridDim.y > 1)
|
||||
{
|
||||
// Direct mapping: no binary search, no padding CTAs
|
||||
kv_head_idx = blockIdx.x;
|
||||
seq_idx = blockIdx.y;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Standard 1D grid with binary search
|
||||
// ...
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Also moved the early-exit check before LDS allocation and binary search.
|
||||
|
||||
**Impact:** ~3% on high-batch decode.
|
||||
|
||||
---
|
||||
|
||||
## Optimization 5: 16x16 MFMA Tiny Decode (kBlockM=16, kBlockQ=2)
|
||||
|
||||
**Problem:** With 32x32 MFMA, minimum kBlockM=32 (1 warp), kBlockQ=4. Triton uses BLOCK_Q=2.
|
||||
|
||||
**Fix:** Use 16x16x32 MFMA instruction with `sequence<16,16,32>` warp tile. The softmax `permlane32_swap` reduction assumes 32x32 MFMA lane layout, so added a conditional fallback:
|
||||
|
||||
```cpp
|
||||
// unified_attention_pipeline.hpp
|
||||
static constexpr ck_tile::index_t kWarpGemmM =
|
||||
UnifiedAttentionShape::Gemm0WarpTile::at(ck_tile::number<0>{});
|
||||
|
||||
// In fmha_alu0 and fmha_alu1:
|
||||
#if defined(__gfx950__)
|
||||
if constexpr(kWarpGemmM == 32)
|
||||
{
|
||||
// permlane32_swap for 32x32 MFMA (2 lanes per row)
|
||||
int32x2_t swapped_regs = __builtin_amdgcn_permlane32_swap(...);
|
||||
m_latest.thread_buf_[0] = f_max(swapped_regs.x, swapped_regs.y);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Generic reduction for 16x16 MFMA (4 lanes per row)
|
||||
block_tile_reduce_sync(m_latest, f_max, bool_constant<false>{});
|
||||
}
|
||||
#endif
|
||||
```
|
||||
|
||||
New traits with `TinyDecodePolicy` (`NumWarpPerGroup=1`):
|
||||
|
||||
```cpp
|
||||
struct unified_attention_decode_tiny_kernel_traits
|
||||
{
|
||||
static constexpr index_t kBlockM = 16;
|
||||
static constexpr index_t BLOCK_SIZE = 64; // kPageBlockSize
|
||||
using unified_attention_warp_gemm_shape = sequence<16, 16, 32>;
|
||||
using unified_attention_block_warps = sequence<1, 1, 1>;
|
||||
// ...
|
||||
};
|
||||
```
|
||||
|
||||
**Impact:** This was the breakthrough. CK went from 37% to 68% win rate. Matches Triton's BLOCK_Q=2 exactly.
|
||||
|
||||
---
|
||||
|
||||
## Optimization 6: 4-Tier Dispatch Heuristic
|
||||
|
||||
**Problem:** Single kernel config for all shapes.
|
||||
|
||||
**Fix:** Shape-adaptive dispatch based on average query length:
|
||||
|
||||
```cpp
|
||||
static tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
{
|
||||
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens;
|
||||
|
||||
if(avg_q <= 2) return tile_tier::tiny; // 1 warp, 16x16 MFMA, kBlockM=16
|
||||
if(avg_q <= 8) return tile_tier::small; // 2 warps, kBlockM=64
|
||||
return tile_tier::medium; // 4 warps, kBlockM=128 (all prefill)
|
||||
}
|
||||
```
|
||||
|
||||
Verified by exhaustive sweep: 4-warp kBlockM=128 outperforms 8-warp kBlockM=256 on **all 71 prefill shapes** (0 exceptions).
|
||||
|
||||
**Impact:** 15-45% improvement on prefill shapes.
|
||||
|
||||
---
|
||||
|
||||
## Kernel Configurations
|
||||
|
||||
| Tier | Warps | MFMA | kBlockM | kBlockQ (GQA-8) | Policy | Use Case |
|
||||
|------|-------|------|---------|-----------------|--------|----------|
|
||||
| Tiny | 1 | 16x16x32 | 16 | 2 | TinyDecode | Pure decode (avg_q ≤ 2) |
|
||||
| Small | 2 | 32x32x16 | 64 | 8 | Decode | Short decode (avg_q ≤ 8) |
|
||||
| Medium | 4 | 32x32x16 | 128 | 16 | Default | All prefill |
|
||||
| Large | 8 | 32x32x16 | 256 | 32 | Default | Unused (4-warp always better) |
|
||||
|
||||
---
|
||||
|
||||
## Instance Files
|
||||
|
||||
20 instance files covering d64/d128 × bf16/fp16 × mask/nomask × decode tiers:
|
||||
|
||||
```
|
||||
instances/unified_attention_d64_bf16_mask_gqa8.cpp # prefill (medium)
|
||||
instances/unified_attention_d64_bf16_mask_gqa8_decode.cpp # small decode
|
||||
instances/unified_attention_d64_bf16_mask_gqa8_decode_s.cpp # small decode (2D grid)
|
||||
instances/unified_attention_d64_bf16_mask_gqa8_decode_t.cpp # tiny decode (16x16 MFMA)
|
||||
# ... (same pattern for bf16_nmask, fp16_mask, fp16_nmask, d128 variants)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## What Didn't Work
|
||||
|
||||
| Attempt | Why it failed |
|
||||
|---------|--------------|
|
||||
| kBlockM=64 with 2x2 warp layout | `permlane32_swap` assumes 1D warp layout; 2D breaks softmax reduction |
|
||||
| 1-warp kBlockM=32 (32x32 MFMA) | Reduced memory bandwidth (1 warp) cancelled the tile waste savings |
|
||||
| sp buffer 2→1 | VGPRs stayed at 132 (compiler minimum); slight decode regression from changed scheduling |
|
||||
| kBlockPerCu=4 | `__launch_bounds__` hint didn't force VGPR reduction on ROCm |
|
||||
| LDS padding changes | Inter-warp padding irrelevant for 1-warp; intra-warp conflicts from MFMA access pattern |
|
||||
| kPageBlockSize=32 | 88 VGPRs / 5 waves, but 2x more KV iterations → 27% slower on low-batch decode |
|
||||
| FMHA develop branch | Standard FMHA fwd kernel 4.6x slower than our decode kernel on 64-seq |
|
||||
|
||||
---
|
||||
|
||||
## Profile (512-seq decode, MI350/gfx950)
|
||||
|
||||
| Resource | Value | Limit | Occupancy |
|
||||
|----------|-------|-------|-----------|
|
||||
| VGPRs | 132 | 512/SIMD | 3 waves/SIMD |
|
||||
| LDS | 38 KB | 160 KB/CU | 4 WGs/CU |
|
||||
| Threads/WG | 64 (1 warp) | - | - |
|
||||
| LDS bank conflicts | 17.8M | - | Intra-warp pattern |
|
||||
|
||||
Bottleneck: VGPRs (132 is compiler minimum for kPageBlockSize=64 with 16x16 MFMA).
|
||||
|
||||
---
|
||||
|
||||
## Files Modified
|
||||
|
||||
**Pipeline:**
|
||||
- `include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp` — serial pipeline, async prefetch, 16x16 MFMA reduction
|
||||
- `include/ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp` — decode policies
|
||||
|
||||
**Kernel:**
|
||||
- `include/ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp` — 2D decode grid, early exit
|
||||
|
||||
**Dispatch:**
|
||||
- `example/ck_tile/42_unified_attention/unified_attention.cpp` — 4-tier dispatch
|
||||
- `example/ck_tile/42_unified_attention/unified_attention_impl.hpp` — decode kernel traits
|
||||
|
||||
**Instances:**
|
||||
- `example/ck_tile/42_unified_attention/instances/` — 12 new decode instance files
|
||||
|
||||
**aiter JIT:**
|
||||
- `aiter/jit/optCompilerConfig.json` — registered decode instance files
|
||||
161
example/ck_tile/42_unified_attention/README.md
Normal file
161
example/ck_tile/42_unified_attention/README.md
Normal file
@@ -0,0 +1,161 @@
|
||||
# fused multi-head attention
|
||||
|
||||
This folder contains examples for unified attention (fused multi-head attention) using the ck_tile tile-programming implementation. The examples demonstrate the usage of the tile-programming API, as well as the new approach to constructing kernel templates and instantiating them.
|
||||
|
||||
## build
|
||||
|
||||
# in the root of ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_unified_attention -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_unified_attention`
|
||||
|
||||
## kernel
|
||||
|
||||
The kernel template is `unified_attention.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
|
||||
|
||||
There are 2 template parameters for this kernel template.
|
||||
|
||||
* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)).
|
||||
* `EpiloguePipeline` is the last stage of the pipeline. It modifies and stores the result. Post-fusion can be done at this stage though the example only returns the result.
|
||||
|
||||
## codegen
|
||||
To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable.
|
||||
|
||||
## executable
|
||||
`tile_example_unified_attention` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_unified_attention -?` to list all the arguments. Below is an example of the output (may subject to change)
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
-mode kernel mode. 0:batch, 1:group (default:0)
|
||||
-b batch size (default:2)
|
||||
-h num of head, for q (default:8)
|
||||
-h_k num of head, for k/v, -1 means equal to h (default:-1)
|
||||
if not equal to h, then this is GQA/MQA case
|
||||
-s seqlen_q. if group-mode, means the average value of seqlen_q (default:3328)
|
||||
total_seqlen_q = seqlen_q * batch, and seqlen_q per batch may vary
|
||||
also with "-s=s0,s1,s2..." comma seperated int to set per batch seqlen(group-mode)
|
||||
-s_k seqlen_k (including new key/value), -1 means equal to s (default:-1)
|
||||
also with "-s_k=s0,s1,s2..." comma-separated ints to set seqlen per batch (group mode)
|
||||
-s_qpad seqlen_q stride between 2 batches (group-mode optional) (default:-1)
|
||||
Provide positive strides per-batch to simulate physical padding on Q
|
||||
-s_kpad seqlen_k stride between 2 batches, currently used in group-mode only (default:-1)
|
||||
for kv-cache case, each batch [1,s,h,d]/[1,h,s,d] can have a stride
|
||||
along seqlen, instead of packed, same as xformer kv_padding,
|
||||
must be greater than or equal to s_k
|
||||
-d head dim for q, k (default:128)
|
||||
-d_v head dim for v, -1 means equal to d (default:-1)
|
||||
-scale_s scale factor of S. 0 means equal to 1/sqrt(hdim). (default:0)
|
||||
note when squant=1, this value will be modified by range_q/k
|
||||
-range_q per-tensor quantization range of q. used if squant=1. (default:16)
|
||||
-range_k per-tensor quantization range of k. used if squant=1. (default:16)
|
||||
-range_v per-tensor quantization range of v. used if squant=1. (default:16)
|
||||
-range_p per-tensor quantization range of p [e^(s-m)]. used if squant=1. (default:1)
|
||||
-range_o per-tensor quantization range of o (p*v). used if squant=1. (default:16)
|
||||
-squant if using static quantization fusion or not. auto: fp8 will default use squant, other will not (default:auto)
|
||||
0: no static quant(not implemented) 1: apply scale_p and scale_o with respect to P and O.
|
||||
calculate scale_s, scale_p, scale_o according to range_q, range_k, range_v, range_p, range_o
|
||||
-iperm permute input (default:1)
|
||||
if true, will be b*h*s*d, else b*s*h*d
|
||||
-operm permute output (default:1)
|
||||
-bias n or 0, no bias (default:n)
|
||||
e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s
|
||||
a(libi) or 2, alibi with 1*h. a:1, b*h
|
||||
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
|
||||
-mask 0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b') (default:0)
|
||||
't', top-left causal mask, 'b', bottom-r causal mask
|
||||
't:l,r', top-left sliding window attn(swa) with FA style left right size
|
||||
'b:l,r', bottom-r sliding window attn(swa) with FA style left right size
|
||||
'xt:window_size', xformer style masking from top-left, window_size negative is causal, positive is swa
|
||||
'xb:window_size', xformer style masking from bottom-r, window_size negative is causal, positive is swa
|
||||
'g:y,x', generic attention mask coordinate with y/x size (only debug purpose for now)
|
||||
-vlayout r for row-major(seqlen*hdim), c for col-major(hdim*seqlen) (default:r)
|
||||
-lse 0 not store lse, 1 store lse (default:0)
|
||||
-kname if set to 1 will print kernel name (default:0)
|
||||
-init init method. ui, uniform random int, ni, normalized random int (default:uf)
|
||||
uf, uniform random float, nf, normalized random float, tf, trig float, uf:q, quantization
|
||||
-seed random seed used for initializing input tensors. 0 for non-deterministic seed (default:11939)
|
||||
-drop_seed seed for random number generator (default:1)
|
||||
-drop_offset offset for random number generator (default:0)
|
||||
-drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0)
|
||||
-num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1)
|
||||
-warmup number of iterations before benchmark the kernel (default:5)
|
||||
-repeat number of iterations to benchmark the kernel (default:20)
|
||||
-json 0: No Json, 1: Dump Results in Json format (default:0)
|
||||
-jsonfile json file name to dump results (default:fmha_fwd.json)
|
||||
-q_eff_lens Batch-mode only: per-batch effective seqlen for Q (exclude PAD) (default:"")
|
||||
Comma-separated list of length 'b'. If empty, no override
|
||||
-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"")
|
||||
Comma-separated list of length 'b'. If empty, no override
|
||||
```
|
||||
Example 1: `./bin/tile_example_unified_attention -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_unified_attention -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
|
||||
|
||||
## Padding Examples
|
||||
Example 3 (Group mode with padding): `./bin/tile_example_unified_attention -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.
|
||||
|
||||
Example 4 (Batch mode with effective lengths): `./bin/tile_example_unified_attention -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.
|
||||
|
||||
## support features
|
||||
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
|
||||
|
||||
### hdim
|
||||
Currently we support `32/64/128/256` hdim for `fp16`/`bf16`, within which `64`/`128` is better optimized. hdim should be multiple of 8, while seqlen_s can be arbitrary. For hdim be arbitrary number, it can be support through padding kernel of `qr` pipeline (we didn't generate this in generate.py by default)
|
||||
|
||||
### group/batch mode
|
||||
Currently we support both `batch mode` and `group mode` (or `varlen`, in FA's term), by setting `-mode` = `0` or `1`. In `group mode` different kind of attention mask is also supported(see below)
|
||||
|
||||
### MQA/GQA
|
||||
By setting `-h`(nhead for q) and `-h_k`(nhead for k/v) with different number, you can achieve MQA/GQA. Please pay attention that `h % h_K == 0` when you set different numbers.
|
||||
|
||||
### input/output permute, and `b*s*3*h*d`
|
||||
If you look at the kernel argument inside `fmha_fwd_kernel.hpp`, we support providing arbitrary stride for seqlen(stride_q/k/v), nhead, batch of q/k/v matrix, hence it is very flexible to support `b*h*s*d` or `b*s*h*d` input/output permute. The `-iperm=0/1`, `-operm=0/1` is a convenient way to achieve this through the executable. We didn't provide a command-line arg to test `b*s*3*h*d` layout which is by default used by torch/FA, but it's trivial to achieve this if one set the proper `stride_q/k/v` value as `3*h*d`.
|
||||
|
||||
### attention bias
|
||||
Attention bias is supported with the layout of `1*1*s*s`(similiar to input/output, different layout can be supported by changing the stride value for bias, or even extend to `b*h*s*s`) and bias value in float number.
|
||||
|
||||
### alibi
|
||||
alibi is supported
|
||||
|
||||
### lse
|
||||
For training kernels, "log sum exp" need to store out in forward and used in backward. We support this by setting `-lse=1`
|
||||
|
||||
### vlayout
|
||||
We support v matrix in both row-major(`seqlen*hdim`) and col-major(`hdim*seqlen`). Since the accumulate(reduce) dimension for V is along `seqlen`, for current AMD's mfma layout which expect each thread to have contiguous register holding pixels along reduce dimension, it's easier to support col-major V layout. However, the performance of col-major is not necessarily faster than row-major, there are many factors that may affect the overall performance. We still provide the `-vlayout=r/c` here to switch/test between different layouts.
|
||||
|
||||
### attention mask
|
||||
we support `causal mask` and `sliding window attention(swa)` mask in both batch and group mode, either from top-left or bottom-right.
|
||||
Underneath, we unify the mask expression into `generic attention mask coordinate`, providing an uniformed approach for each batch to locate the corresponding pixel need to be masked out.
|
||||

|
||||
|
||||
Since FA/xformer style with window_size_left/right is more popular, we accept window_size as parameter and convert that internally to our generic coordinate(this coordinate can express more cases). Below shows some example of how to achieve different kind of mask through cmdline.
|
||||
|
||||
| mask case| cmdline | FA style | xformer style |
|
||||
|----------|:-------------:|:-------------:|:-------------:|
|
||||
| no mask | `-mask=0`(default) | | |
|
||||
| causal mask from top-left | `-mask=1` or `-mask=t` | `-mask=t:-1,0` | `-mask=xt:-1` |
|
||||
| causal mask from bottom-right | `-mask=2` or `-mask=b` | `-mask=b:-1,0` | `-mask=xb:-1` |
|
||||
| swa from top-left | | `-mask=t:3,5` | `-mask=xt:4` |
|
||||
| swa from bottom-right | | `-mask=b:10,11` | `-mask=xb:16` |
|
||||
|
||||
Note FA use bottom-right by default to express swa case, here we require you explicitly specify top-left/bottom-right.
|
||||
|
||||
### dropout
|
||||
TBD
|
||||
|
||||
### sequence padding and variable length support
|
||||
We support sequence padding and variable-length processing in both batch and group modes fmha forward to handle real-world scenarios where sequences have different lengths.
|
||||
|
||||
**Group Mode Padding**: Use `-s_qpad` and `-s_kpad` to specify physical stride between batches, enabling padded layouts. Each batch can have different logical sequence lengths (`-s`, `-s_k`) but use larger physical strides for memory alignment.
|
||||
|
||||
**Batch Mode Variable Length**: Use `-q_eff_lens` and `-kv_eff_lens` to specify effective sequence lengths per batch. All batches share the same physical sequence length, but the kernel processes only the effective portions. This enables efficient variable-length attention without memory waste.
|
||||
|
||||
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
|
||||
|
||||
## FP8 experimental support
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_unified_attention`, on a gfx942 machine and ROCm 6.0+.
|
||||
|
||||
Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later.
|
||||
@@ -0,0 +1,679 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <iomanip>
|
||||
#include <iostream>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include <ck_tile/core/numeric/bfloat16.hpp>
|
||||
#include <ck_tile/core/numeric/half.hpp>
|
||||
#include <ck_tile/core/numeric/math.hpp>
|
||||
#include <ck_tile/core/utility/functional.hpp>
|
||||
#include <ck_tile/host/arg_parser.hpp>
|
||||
#include <ck_tile/host/device_memory.hpp>
|
||||
#include <ck_tile/host/fill.hpp>
|
||||
#include <ck_tile/host/check_err.hpp>
|
||||
#include <ck_tile/host/host_tensor.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_gemm.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_masking.hpp>
|
||||
#include <ck_tile/host/reference/reference_batched_softmax.hpp>
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
// const ck_tile::index_t page_blk_size = 32;
|
||||
// num_queries_per_kv is now a runtime arg (see parse_cmd_args)
|
||||
|
||||
auto parse_cmd_args(int argc, char* argv[]) -> std::pair<bool, ck_tile::ArgParser>
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser
|
||||
.insert("prec", "bf16", "data type. fp16/bf16")
|
||||
// .insert("b", "3", "batch size")
|
||||
.insert("nqpkv", "1", "num queries per kv head (GQA ratio, e.g. 1 for MHA, 8 for GQA-8)")
|
||||
.insert("h_k", "8", "num head for k/v. num head for q is nqpkv times this")
|
||||
.insert("s", "3328", "max seqlen_q")
|
||||
.insert("s_k", "-1", "max seqlen_k, -1 means equal to s")
|
||||
.insert("nb", "1024", "num_blks")
|
||||
.insert("b", "3", "batch")
|
||||
.insert("d", "128", "head dim for q & k")
|
||||
.insert("scale_s", "0", "scale factor of S. 0 means equal to 1/sqrt(hdim)")
|
||||
// TODO scale factors
|
||||
.insert("scale", "1", "")
|
||||
.insert("scale_k", "1", "")
|
||||
.insert("scale_v", "1", "")
|
||||
.insert("scale_out", "1", "")
|
||||
.insert("iperm",
|
||||
"0",
|
||||
"permute input\n"
|
||||
"if true, will be b*h*s*d, else b*s*h*d")
|
||||
.insert("operm", "0", "permute output")
|
||||
.insert("causal", "0", "0: no mask, 1: causal mask")
|
||||
.insert("verify", "1", "0:no verify, 1:verify")
|
||||
.insert("varlen", "1", "0: fixed length, 1: variable length")
|
||||
.insert("seed",
|
||||
"11939",
|
||||
"random seed used for initializing input tensors. 0 for "
|
||||
"non-deterministic seed")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "30", "number of iterations to benchmark the kernel")
|
||||
.insert("page_blk_size", "128", "page block size of kv cache")
|
||||
// Optional effective seqlen override (exclude PAD) for batch mode
|
||||
.insert("query_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for Q (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.")
|
||||
.insert("kv_lens",
|
||||
"",
|
||||
"Batch-mode only: per-batch effective seqlen for KV (exclude PAD).\n"
|
||||
"Comma-separated list of length 'b'. If empty, no override.");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_pair(result, arg_parser);
|
||||
}
|
||||
|
||||
auto seqlen_preprocess(ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t max_seqlen_kv,
|
||||
const std::vector<int>& query_lens_input,
|
||||
const std::vector<int>& kv_lens_input,
|
||||
bool varlen) -> std::pair<std::vector<int>, std::vector<int>>
|
||||
{
|
||||
// If both query_lens and kv_lens are provided, return them directly
|
||||
if(!query_lens_input.empty() && !kv_lens_input.empty())
|
||||
{
|
||||
return std::make_pair(query_lens_input, kv_lens_input);
|
||||
}
|
||||
|
||||
std::vector<int> query_lens;
|
||||
std::vector<int> kv_lens;
|
||||
|
||||
if(!varlen)
|
||||
{
|
||||
// Fixed length mode: fill with max seqlen
|
||||
query_lens.assign(batch, max_seqlen_q);
|
||||
kv_lens.assign(batch, max_seqlen_kv);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Variable length mode: generate random lengths up to max
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<int> q_dist(1, max_seqlen_q);
|
||||
std::uniform_int_distribution<int> kv_dist(1, max_seqlen_kv);
|
||||
|
||||
query_lens.resize(batch);
|
||||
kv_lens.resize(batch);
|
||||
|
||||
for(ck_tile::index_t i = 0; i < batch; ++i)
|
||||
{
|
||||
query_lens[i] = q_dist(gen);
|
||||
kv_lens[i] = kv_dist(gen);
|
||||
}
|
||||
}
|
||||
|
||||
return std::make_pair(query_lens, kv_lens);
|
||||
}
|
||||
|
||||
struct Problem
|
||||
{
|
||||
explicit Problem(const ck_tile::ArgParser& args)
|
||||
{
|
||||
data_type = args.get_str("prec") == "fp16"
|
||||
? ck_tile::unified_attention_args::data_type_enum::fp16
|
||||
: ck_tile::unified_attention_args::data_type_enum::bf16;
|
||||
num_blks = args.get_int("nb");
|
||||
nhead_kv = args.get_int("h_k");
|
||||
num_queries_per_kv = args.get_int("nqpkv");
|
||||
nhead_q = nhead_kv * num_queries_per_kv;
|
||||
|
||||
ck_tile::index_t max_seqlen_q = args.get_int("s");
|
||||
ck_tile::index_t max_seqlen_kv = args.get_int("s_k");
|
||||
|
||||
if(max_seqlen_kv == -1)
|
||||
{
|
||||
max_seqlen_kv = max_seqlen_q;
|
||||
}
|
||||
|
||||
hdim = args.get_int("d");
|
||||
query_lens = args.get_int_vec("query_lens");
|
||||
kv_lens = args.get_int_vec("kv_lens");
|
||||
assert(query_lens.size() == kv_lens.size() &&
|
||||
"query_lens and kv_lens must have the same length b");
|
||||
batch = args.get_int("b");
|
||||
page_blk_size = args.get_int("page_blk_size");
|
||||
|
||||
bool varlen = args.get_bool("varlen");
|
||||
auto [query_lens_, kv_lens_] =
|
||||
seqlen_preprocess(batch, max_seqlen_q, max_seqlen_kv, query_lens, kv_lens, varlen);
|
||||
|
||||
query_lens = query_lens_;
|
||||
kv_lens = kv_lens_;
|
||||
batch = query_lens.size();
|
||||
|
||||
// Calculate scale_s
|
||||
scale_s = args.get_float("scale_s");
|
||||
if(scale_s == 0.0f)
|
||||
scale_s = 1.0f / ck_tile::sqrt(static_cast<float>(hdim));
|
||||
|
||||
// Initialize other scales
|
||||
scale = args.get_float("scale");
|
||||
scale_k = args.get_float("scale_k");
|
||||
scale_v = args.get_float("scale_v");
|
||||
num_tokens = 0;
|
||||
for(const auto& len : query_lens)
|
||||
{
|
||||
num_tokens += len;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_query_shape() const { return {num_tokens, nhead_q, hdim}; }
|
||||
|
||||
std::vector<ck_tile::index_t> get_key_shape() const
|
||||
{
|
||||
return {num_blks, page_blk_size, nhead_kv, hdim};
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_value_shape() const
|
||||
{
|
||||
return {num_blks, page_blk_size, nhead_kv, hdim};
|
||||
}
|
||||
|
||||
std::vector<ck_tile::index_t> get_output_shape() const { return {num_tokens, nhead_q, hdim}; }
|
||||
|
||||
ck_tile::unified_attention_args::data_type_enum data_type;
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t num_blks;
|
||||
ck_tile::index_t nhead_q;
|
||||
ck_tile::index_t nhead_kv;
|
||||
ck_tile::index_t num_queries_per_kv;
|
||||
ck_tile::index_t hdim;
|
||||
ck_tile::index_t page_blk_size;
|
||||
ck_tile::index_t num_tokens;
|
||||
float scale_s;
|
||||
float scale;
|
||||
float scale_k;
|
||||
float scale_v;
|
||||
mask_info mask;
|
||||
std::vector<int> query_lens;
|
||||
std::vector<int> kv_lens;
|
||||
};
|
||||
|
||||
struct RunConfig
|
||||
{
|
||||
explicit RunConfig(const ck_tile::ArgParser& args)
|
||||
{
|
||||
seed = args.get_uint32("seed");
|
||||
if(*seed == 0)
|
||||
{
|
||||
seed.reset();
|
||||
}
|
||||
|
||||
kernel_warmup = args.get_int("warmup");
|
||||
kernel_repeat = args.get_int("repeat");
|
||||
verify = args.get_bool("verify");
|
||||
}
|
||||
|
||||
std::optional<uint32_t> seed;
|
||||
int kernel_warmup;
|
||||
int kernel_repeat;
|
||||
bool verify;
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
auto generate_qkv(const Problem& problem,
|
||||
[[maybe_unused]] std::optional<uint32_t> seed = std::nullopt)
|
||||
-> std::tuple<ck_tile::HostTensor<DataType>,
|
||||
ck_tile::HostTensor<DataType>,
|
||||
ck_tile::HostTensor<DataType>>
|
||||
{
|
||||
ck_tile::HostTensor<DataType> q(problem.get_query_shape());
|
||||
ck_tile::HostTensor<DataType> k(problem.get_key_shape());
|
||||
ck_tile::HostTensor<DataType> v(problem.get_value_shape());
|
||||
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(q);
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(k);
|
||||
ck_tile::FillNormalDistribution<DataType>{0.f, 3.f, seed}(v);
|
||||
|
||||
return std::make_tuple(q, k, v);
|
||||
}
|
||||
|
||||
namespace host {
|
||||
template <typename AccDataType,
|
||||
typename PDataType,
|
||||
typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename QElementOp,
|
||||
typename KElementOp,
|
||||
typename VElementOp,
|
||||
typename SAccElementOp>
|
||||
CK_TILE_HOST void fmha_fwd(const ck_tile::HostTensor<QDataType>& q_bshd,
|
||||
const ck_tile::HostTensor<KDataType>& k_bshd,
|
||||
const ck_tile::HostTensor<VDataType>& v_bshd,
|
||||
// const mask_info& mask,
|
||||
ck_tile::HostTensor<ODataType>& o_bshd,
|
||||
const QElementOp& q_element_op = {},
|
||||
const KElementOp& k_element_op = {},
|
||||
const VElementOp& v_element_op = {},
|
||||
const SAccElementOp& s_acc_element_op = {})
|
||||
{
|
||||
const int batch_size = q_bshd.mDesc.get_lengths()[0];
|
||||
const int seqlen_q = q_bshd.mDesc.get_lengths()[1];
|
||||
const int seqlen_kv = k_bshd.mDesc.get_lengths()[1];
|
||||
const int nhead_q = q_bshd.mDesc.get_lengths()[2];
|
||||
const int nhead_kv = k_bshd.mDesc.get_lengths()[2];
|
||||
const int hdim_qk = q_bshd.mDesc.get_lengths()[3];
|
||||
const int hdim_v = v_bshd.mDesc.get_lengths()[3];
|
||||
|
||||
const int nr = nhead_q / nhead_kv;
|
||||
|
||||
ck_tile::HostTensor<QDataType> q_host_ref({nhead_q, seqlen_q, hdim_qk});
|
||||
ck_tile::HostTensor<KDataType> k_host_ref({nhead_q, seqlen_kv, hdim_qk});
|
||||
ck_tile::HostTensor<VDataType> v_host_ref({nhead_q, hdim_v, seqlen_kv});
|
||||
ck_tile::HostTensor<ODataType> o_host_ref({nhead_q, seqlen_q, hdim_v});
|
||||
|
||||
ck_tile::HostTensor<AccDataType> s_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
ck_tile::HostTensor<PDataType> p_host_ref({nhead_q, seqlen_q, seqlen_kv});
|
||||
// do computation for each batch
|
||||
for(int b = 0; b < batch_size; ++b)
|
||||
{
|
||||
// copy per-batch data from input tensors
|
||||
// clang-format off
|
||||
q_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = q_bshd(b, idx[1], idx[0] ,
|
||||
idx[2]); });
|
||||
k_host_ref.ForEach([&](auto& self, auto idx) { self(idx) = k_bshd(b, idx[1],
|
||||
idx[0] / nr, idx[2]); });
|
||||
v_host_ref.ForEach([&](auto& self, auto idx) { self(idx) =
|
||||
v_bshd(b, idx[2], idx[0] / nr, idx[1]); });
|
||||
// clang-format on
|
||||
ck_tile::reference_batched_gemm<QDataType, KDataType, AccDataType>(
|
||||
q_host_ref, k_host_ref, s_host_ref, q_element_op, k_element_op, s_acc_element_op);
|
||||
|
||||
ck_tile::reference_batched_masking(
|
||||
s_host_ref,
|
||||
ck_tile::make_generic_attention_mask_from_lr_window<UnifiedAttentionMasks::CausalMask>(
|
||||
-1, 0, seqlen_q, seqlen_kv, 1, false));
|
||||
ck_tile::reference_batched_softmax<AccDataType, AccDataType>(
|
||||
s_host_ref, p_host_ref, ck_tile::identity{});
|
||||
ck_tile::reference_batched_gemm<PDataType, VDataType, AccDataType>(
|
||||
p_host_ref, v_host_ref, o_host_ref, ck_tile::identity{}, v_element_op);
|
||||
|
||||
// copy resulting per-batch data to the output tensor
|
||||
o_host_ref.ForEach(
|
||||
[&](auto& self, auto idx) { o_bshd(b, idx[1], idx[0], idx[2]) = self(idx); });
|
||||
}
|
||||
}
|
||||
} // namespace host
|
||||
|
||||
template <typename DataType>
|
||||
bool run_impl(const Problem& problem, const RunConfig& run_config)
|
||||
{
|
||||
auto [q, k, v] = generate_qkv<DataType>(problem, run_config.seed);
|
||||
|
||||
ck_tile::DeviceMem q_buf(q.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem k_buf(k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem v_buf(v.get_element_space_size_in_bytes());
|
||||
/// FIXME: use correct size for output tensor. just use q size for now since hidm_qk = hdim_v
|
||||
ck_tile::DeviceMem o_buf(q.get_element_space_size_in_bytes());
|
||||
|
||||
q_buf.ToDevice(q.data());
|
||||
k_buf.ToDevice(k.data());
|
||||
v_buf.ToDevice(v.data());
|
||||
// Ensure output buffer is zero-initialized so padded regions compare cleanly
|
||||
o_buf.SetZero();
|
||||
|
||||
ck_tile::unified_attention_args args{};
|
||||
|
||||
args.scale_s = problem.scale_s;
|
||||
args.data_type = problem.data_type;
|
||||
args.num_seqs = problem.batch;
|
||||
args.num_head_q = problem.nhead_q;
|
||||
args.num_queries_per_kv = problem.num_queries_per_kv;
|
||||
args.page_blk_size = problem.page_blk_size;
|
||||
args.mask_type = 2;
|
||||
args.hdim = problem.hdim;
|
||||
|
||||
args.num_blks = problem.num_blks;
|
||||
|
||||
args.q_ptr = q_buf.GetDeviceBuffer();
|
||||
args.query_stride_0 = problem.hdim * problem.nhead_q;
|
||||
args.query_stride_1 = problem.hdim;
|
||||
|
||||
args.k_ptr = k_buf.GetDeviceBuffer();
|
||||
|
||||
args.stride_k_cache_0 = problem.hdim * problem.nhead_kv * problem.page_blk_size;
|
||||
args.stride_k_cache_1 = problem.hdim * problem.nhead_kv;
|
||||
args.stride_k_cache_2 = problem.hdim;
|
||||
args.stride_k_cache_3 = 1;
|
||||
|
||||
args.v_ptr = v_buf.GetDeviceBuffer();
|
||||
args.stride_v_cache_0 = args.stride_k_cache_0;
|
||||
args.stride_v_cache_1 = args.stride_k_cache_1;
|
||||
args.stride_v_cache_2 = args.stride_k_cache_2;
|
||||
args.stride_v_cache_3 = args.stride_k_cache_3;
|
||||
|
||||
args.o_ptr = o_buf.GetDeviceBuffer();
|
||||
args.output_stride_0 = args.query_stride_0;
|
||||
args.output_stride_1 = args.query_stride_1;
|
||||
|
||||
// Optional cumulative seqlen overrides (exclude PAD)
|
||||
auto make_effective_vec = [&](const std::vector<int>& opt_vec, ck_tile::index_t fallback) {
|
||||
std::vector<ck_tile::index_t> eff;
|
||||
if(!opt_vec.empty() && opt_vec[0] != -1)
|
||||
{
|
||||
eff.assign(opt_vec.begin(), opt_vec.end());
|
||||
if(eff.size() < static_cast<size_t>(problem.batch))
|
||||
{
|
||||
eff.resize(problem.batch, eff.back());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
eff.assign(problem.batch, fallback);
|
||||
}
|
||||
return eff;
|
||||
};
|
||||
|
||||
const auto eff_query_lens = make_effective_vec(problem.query_lens, 1024);
|
||||
const auto eff_kv_lens = make_effective_vec(problem.kv_lens, 1024);
|
||||
|
||||
args.num_tokens = std::accumulate(eff_query_lens.begin(), eff_query_lens.end(), 0);
|
||||
|
||||
// Calculate cumulative sums for kernel arguments if varlen is used
|
||||
std::vector<ck_tile::index_t> cu_query_lens;
|
||||
|
||||
auto calculate_cumulative = [&](const std::vector<ck_tile::index_t>& per_batch_vec,
|
||||
std::vector<ck_tile::index_t>& cum_vec) {
|
||||
cum_vec.resize(per_batch_vec.size() + 1);
|
||||
cum_vec[0] = 0;
|
||||
for(std::size_t i = 0; i < per_batch_vec.size(); ++i)
|
||||
cum_vec[i + 1] = cum_vec[i] + per_batch_vec[i];
|
||||
};
|
||||
calculate_cumulative(eff_query_lens, cu_query_lens);
|
||||
|
||||
ck_tile::DeviceMem seq_lens_buf(eff_kv_lens.size() * sizeof(ck_tile::index_t));
|
||||
ck_tile::DeviceMem query_start_len_buf(cu_query_lens.size() * sizeof(ck_tile::index_t));
|
||||
|
||||
seq_lens_buf.ToDevice(eff_kv_lens.data());
|
||||
query_start_len_buf.ToDevice(cu_query_lens.data());
|
||||
|
||||
args.seq_lens_ptr = reinterpret_cast<const ck_tile::index_t*>(seq_lens_buf.GetDeviceBuffer());
|
||||
args.query_start_len_ptr =
|
||||
reinterpret_cast<const ck_tile::index_t*>(query_start_len_buf.GetDeviceBuffer());
|
||||
|
||||
auto max_element = [&](const std::vector<ck_tile::index_t>& opt_vec) {
|
||||
ck_tile::index_t max = opt_vec[0];
|
||||
for(ck_tile::index_t i : opt_vec)
|
||||
{
|
||||
if(i > max)
|
||||
{
|
||||
max = i;
|
||||
}
|
||||
}
|
||||
return max;
|
||||
};
|
||||
|
||||
ck_tile::index_t max_kv_len = max_element(eff_kv_lens);
|
||||
|
||||
ck_tile::index_t max_num_blocks_per_seq =
|
||||
(max_kv_len + problem.page_blk_size - 1) / problem.page_blk_size;
|
||||
|
||||
// Create block_tables
|
||||
ck_tile::DeviceMem block_tables_buf(problem.batch * max_num_blocks_per_seq *
|
||||
sizeof(ck_tile::index_t));
|
||||
|
||||
// Allocate host memory for block_tables
|
||||
std::vector<ck_tile::index_t> block_tables_host(problem.batch * max_num_blocks_per_seq);
|
||||
|
||||
// Fill block_tables with random integers between 0 and num_blocks-1
|
||||
std::mt19937 rng(run_config.seed ? *run_config.seed : std::random_device{}());
|
||||
std::uniform_int_distribution<ck_tile::index_t> dist(0, problem.num_blks - 1);
|
||||
for(size_t i = 0; i < block_tables_host.size(); ++i)
|
||||
{
|
||||
block_tables_host[i] = dist(rng);
|
||||
}
|
||||
|
||||
// Copy to device
|
||||
block_tables_buf.ToDevice(block_tables_host.data());
|
||||
|
||||
// Set pointer in args
|
||||
args.block_tables_ptr =
|
||||
reinterpret_cast<const ck_tile::index_t*>(block_tables_buf.GetDeviceBuffer());
|
||||
args.block_table_stride = max_num_blocks_per_seq;
|
||||
|
||||
ck_tile::stream_config stream_config{nullptr,
|
||||
true,
|
||||
/*log_level=*/0,
|
||||
run_config.kernel_warmup,
|
||||
run_config.kernel_repeat};
|
||||
|
||||
auto [result, time] = ck_tile::unified_attention(args, stream_config);
|
||||
|
||||
if(!result)
|
||||
{
|
||||
std::cerr << "faild to run unified_attention()" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t flop = [&] {
|
||||
long flop_result = 0;
|
||||
|
||||
for(size_t b = 0; b < eff_query_lens.size(); ++b)
|
||||
{
|
||||
long query_lens = eff_query_lens[b];
|
||||
long kv_lens = eff_kv_lens[b];
|
||||
long valid_out_elements = 0;
|
||||
|
||||
// Causal logic for valid output elements
|
||||
if(query_lens > kv_lens)
|
||||
{
|
||||
valid_out_elements = (kv_lens * kv_lens + kv_lens) / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
valid_out_elements =
|
||||
query_lens * kv_lens - ((query_lens * query_lens - query_lens) / 2);
|
||||
}
|
||||
|
||||
flop_result += 2 * problem.nhead_q * valid_out_elements * (problem.hdim + problem.hdim);
|
||||
}
|
||||
return flop_result;
|
||||
}();
|
||||
// TODO fix this
|
||||
// std::size_t flop = 1;
|
||||
float tflops = static_cast<float>(flop) / 1.e9 / time;
|
||||
long mem = 0;
|
||||
|
||||
mem += problem.num_tokens * problem.nhead_q * problem.hdim * 2 * 2; // q and o, fp16
|
||||
// Count unique block indices used in block_tables_host
|
||||
std::unordered_set<ck_tile::index_t> unique_blocks(block_tables_host.begin(),
|
||||
block_tables_host.end());
|
||||
mem += unique_blocks.size() * problem.nhead_kv * problem.hdim * 2 * 2; // k and v, fp16
|
||||
mem += problem.batch * max_num_blocks_per_seq * 4; // int32 block table
|
||||
mem += problem.batch * 4; // int32 seq_lens_ptr
|
||||
|
||||
std::cout << "[" << problem.data_type << "|";
|
||||
std::cout << "] b:" << problem.batch << ", h:" << problem.nhead_q << "/" << problem.nhead_kv
|
||||
<< ", d:" << problem.hdim << ", scale_s:" << problem.scale_s << ", query_lens:[";
|
||||
for(size_t i = 0; i < problem.query_lens.size(); ++i)
|
||||
{
|
||||
std::cout << problem.query_lens[i];
|
||||
if(i < problem.query_lens.size() - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
std::cout << "], kv_lens:[";
|
||||
for(size_t i = 0; i < problem.kv_lens.size(); ++i)
|
||||
{
|
||||
std::cout << problem.kv_lens[i];
|
||||
if(i < problem.kv_lens.size() - 1)
|
||||
std::cout << ",";
|
||||
}
|
||||
std::cout << "], mask:" << "causal mask" << std::fixed << ", " << std::setprecision(8) << time
|
||||
<< " ms, " << std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2)
|
||||
<< (static_cast<double>(mem) / 1e12 / (time / 1e3)) << " TB/s" << std::endl;
|
||||
|
||||
if(!run_config.verify)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// variable lengths are provided -> compute per-batch references
|
||||
// with the effective lengths; else compute a single full reference.
|
||||
// Variable-length aware verification: zero-fill padded region and only compute valid part.
|
||||
ck_tile::HostTensor<DataType> o_ref(problem.get_output_shape());
|
||||
o_ref.SetZero();
|
||||
|
||||
for(int b = 0; b < problem.batch; ++b)
|
||||
{
|
||||
const ck_tile::index_t seqlen_q_eff = eff_query_lens[b];
|
||||
const ck_tile::index_t seqlen_kv_eff = eff_kv_lens[b];
|
||||
|
||||
if(seqlen_q_eff <= 0 || seqlen_kv_eff <= 0)
|
||||
continue;
|
||||
|
||||
// Slice current batch from inputs (bshd) and build single-batch tensors
|
||||
ck_tile::HostTensor<DataType> q_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> k_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> v_b({1, seqlen_kv_eff, problem.nhead_kv, problem.hdim});
|
||||
ck_tile::HostTensor<DataType> o_b({1, seqlen_q_eff, problem.nhead_q, problem.hdim});
|
||||
ck_tile::index_t seq_q_off = cu_query_lens[b];
|
||||
|
||||
// Copy effective region
|
||||
q_b.ForEach([&](auto& self, auto idx) {
|
||||
// idx: [0, s, h, d]
|
||||
self(idx) = q(seq_q_off + idx[1], idx[2], idx[3]);
|
||||
});
|
||||
k_b.ForEach([&](auto& self, auto idx) {
|
||||
// kv cache is paged
|
||||
ck_tile::index_t table_col = int(idx[1] / problem.page_blk_size);
|
||||
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
|
||||
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
|
||||
|
||||
self(idx) = k(block_idx, idx[1] % problem.page_blk_size, idx[2], idx[3]);
|
||||
});
|
||||
v_b.ForEach([&](auto& self, auto idx) {
|
||||
ck_tile::index_t table_col = int(idx[1] / problem.page_blk_size);
|
||||
ck_tile::index_t block_table_offset = b * max_num_blocks_per_seq + table_col;
|
||||
ck_tile::index_t block_idx = block_tables_host[block_table_offset];
|
||||
|
||||
self(idx) = v(block_idx, idx[1] % problem.page_blk_size, idx[2], idx[3]);
|
||||
});
|
||||
// v_b.ForEach([&](auto& self, auto idx) { self(idx) = v(b, idx[1], idx[2], idx[3]); });
|
||||
|
||||
// Compute reference for this batch segment (host::fmha_fwd expects bshd tensors)
|
||||
host::fmha_fwd<float, DataType>(q_b,
|
||||
k_b,
|
||||
v_b,
|
||||
// problem.mask,
|
||||
o_b,
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::identity{},
|
||||
ck_tile::scales{problem.scale_s});
|
||||
|
||||
// Scatter into o_ref's bshd descriptor memory
|
||||
for(int s = 0; s < seqlen_q_eff; ++s)
|
||||
{
|
||||
for(int h = 0; h < problem.nhead_q; ++h)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
o_ref(seq_q_off + s, h, d) = o_b(0, s, h, d);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<DataType> o(problem.get_output_shape());
|
||||
o_buf.FromDevice(o.data());
|
||||
|
||||
const auto [rtol, atol] = [&] {
|
||||
if constexpr(std::is_same_v<DataType, ck_tile::fp16_t>)
|
||||
return std::make_tuple(1e-3, 1e-3);
|
||||
else
|
||||
return std::make_tuple(1e-2, 1e-2);
|
||||
}();
|
||||
|
||||
size_t total = static_cast<size_t>(problem.num_tokens) * static_cast<size_t>(problem.nhead_q) *
|
||||
static_cast<size_t>(problem.hdim);
|
||||
|
||||
size_t nonzero = 0;
|
||||
|
||||
for(int tok = 0; tok < problem.num_tokens; ++tok)
|
||||
{
|
||||
for(int h = 0; h < problem.nhead_q; ++h)
|
||||
{
|
||||
for(int d = 0; d < problem.hdim; ++d)
|
||||
{
|
||||
if(static_cast<float>(o(tok, h, d)) != 0.0f)
|
||||
{
|
||||
nonzero++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float percent =
|
||||
(total > 0) ? (100.0f * static_cast<float>(nonzero) / static_cast<float>(total)) : 0.0f;
|
||||
|
||||
std::cout << "\nNon-zero elements in output tensor o: " << nonzero << " / " << total << " ("
|
||||
<< percent << "%)\n";
|
||||
|
||||
// std::cout << "\n=== Complete Output Tensor (o) ===\n";
|
||||
// for (int tok = 0; tok < problem.num_tokens; ++tok) {
|
||||
// std::cout << "Token " << tok << ":\n";
|
||||
// for (int h = 0; h < problem.nhead_q; ++h) {
|
||||
// std::cout << " Head " << h << ": ";
|
||||
// for (int d = 0; d < problem.hdim; ++d) {
|
||||
// std::cout << static_cast<float>(o(tok, h, d)) << " ";
|
||||
// }
|
||||
// std::cout << "\n";
|
||||
// }
|
||||
// }
|
||||
|
||||
// std::cout << "\n=== Complete Reference Tensor (o_ref) ===\n";
|
||||
// for (int tok = 0; tok < problem.num_tokens; ++tok) {
|
||||
// std::cout << "Token " << tok << ":\n";
|
||||
// for (int h = 0; h < problem.nhead_q; ++h) {
|
||||
// std::cout << " Head " << h << ": ";
|
||||
// for (int d = 0; d < problem.hdim; ++d) {
|
||||
// std::cout << static_cast<float>(o_ref(tok, h, d)) << " ";
|
||||
// }
|
||||
// std::cout << "\n";
|
||||
// }
|
||||
// }
|
||||
return ck_tile::check_err(o, o_ref, std::string("found incorrect results!"), rtol, atol);
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
|
||||
auto [parse_result, args] = parse_cmd_args(argc, argv);
|
||||
|
||||
if(!parse_result)
|
||||
{
|
||||
std::cerr << "failed to parse command line arguments" << std::endl;
|
||||
}
|
||||
|
||||
Problem problem(args);
|
||||
RunConfig run_config(args);
|
||||
|
||||
const auto run = [&] {
|
||||
if(problem.data_type == ck_tile::unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
return run_impl<ck_tile::fp16_t>(problem, run_config);
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_impl<ck_tile::bf16_t>(problem, run_config);
|
||||
}
|
||||
};
|
||||
|
||||
return !run();
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 32, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 128, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 64, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, true, 64, 16, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 32, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 128, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 64, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::bf16, false, 64, 16, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 32, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 128, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 64, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, true, 64, 16, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 256, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_bs32_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 32, 8, 32>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 128, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_small_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 64, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,14 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
using kernel_traits =
|
||||
unified_attention_decode_tiny_kernel_traits<unified_attention_args::data_type_enum::fp16, false, 64, 16, 8>;
|
||||
|
||||
INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits)
|
||||
|
||||
} // namespace ck_tile
|
||||
167
example/ck_tile/42_unified_attention/mask.hpp
Normal file
167
example/ck_tile/42_unified_attention/mask.hpp
Normal file
@@ -0,0 +1,167 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/unified_attention.hpp"
|
||||
|
||||
// keep this in sync with ck_tile::GenericAttentionMaskEnum
|
||||
enum class mask_enum
|
||||
{
|
||||
no_mask = 0,
|
||||
mask_top_left,
|
||||
mask_bottom_right,
|
||||
window_generic,
|
||||
};
|
||||
|
||||
struct mask_info
|
||||
{
|
||||
mask_enum type;
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t seqlen_k;
|
||||
ck_tile::index_t y, x;
|
||||
ck_tile::index_t left, right; // FA style SWA left/right
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
os << "n";
|
||||
else if(type == mask_enum::mask_top_left)
|
||||
os << "t(" << left << ":" << right << ")";
|
||||
else if(type == mask_enum::mask_bottom_right)
|
||||
os << "b(" << left << ":" << right << ")";
|
||||
else
|
||||
{
|
||||
os << "g(" << y << ":" << x << ")";
|
||||
}
|
||||
}
|
||||
|
||||
static mask_info decode(std::string str, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k)
|
||||
{
|
||||
ck_tile::index_t x_total = seqlen_k;
|
||||
ck_tile::index_t y_total = seqlen_q;
|
||||
mask_info tmp;
|
||||
tmp.seqlen_q = seqlen_q;
|
||||
tmp.seqlen_k = seqlen_k;
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string t = str.substr(0, found_0);
|
||||
std::string v = str.substr(found_0 + 1);
|
||||
if(t == "xt" || t == "xb")
|
||||
{
|
||||
// xformer style sliding window attn from top-left
|
||||
ck_tile::index_t window_size = std::stoi(v);
|
||||
ck_tile::index_t left_size = -1;
|
||||
ck_tile::index_t right_size = 0;
|
||||
if(window_size > 0)
|
||||
{
|
||||
left_size = window_size / 2;
|
||||
right_size = window_size - 1 - left_size;
|
||||
}
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
left_size, right_size, y_total, x_total, t == "xt");
|
||||
|
||||
tmp.type = t == "xt" ? mask_enum::mask_top_left : mask_enum::mask_bottom_right;
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = left_size;
|
||||
tmp.right = right_size;
|
||||
}
|
||||
else if(t == "t" || t == "b" || t == "g")
|
||||
{
|
||||
auto found_1 = v.find(",");
|
||||
if(found_1 == std::string::npos)
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
ck_tile::index_t v0 = std::stoi(v.substr(0, found_1));
|
||||
ck_tile::index_t v1 = std::stoi(v.substr(found_1 + 1));
|
||||
if(t == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, true);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
auto r = ck_tile::make_generic_attention_mask_coordinates_from_lr_window(
|
||||
v0, v1, y_total, x_total, false);
|
||||
tmp.y = r.at(ck_tile::number<0>{});
|
||||
tmp.x = r.at(ck_tile::number<1>{});
|
||||
tmp.left = v0;
|
||||
tmp.right = v1;
|
||||
}
|
||||
else if(t == "g")
|
||||
{
|
||||
tmp.type = mask_enum::window_generic;
|
||||
tmp.y = v0;
|
||||
tmp.x = v1;
|
||||
tmp.left = v0; // TODO: don't use this?
|
||||
tmp.right = v1;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
}
|
||||
else if(str == "0")
|
||||
{
|
||||
tmp.type = mask_enum::no_mask;
|
||||
}
|
||||
else if(str == "1" || str == "t")
|
||||
{
|
||||
tmp.type = mask_enum::mask_top_left;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
}
|
||||
else if(str == "2" || str == "b")
|
||||
{
|
||||
tmp.type = mask_enum::mask_bottom_right;
|
||||
tmp.y = seqlen_q;
|
||||
tmp.x = seqlen_k - seqlen_q + 1;
|
||||
tmp.left = -1;
|
||||
tmp.right = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid mask value: " + str);
|
||||
}
|
||||
return tmp;
|
||||
}
|
||||
|
||||
ck_tile::index_t get_unmaskarea() const
|
||||
{
|
||||
if(type == mask_enum::no_mask)
|
||||
return seqlen_q * seqlen_k;
|
||||
ck_tile::index_t area = 0;
|
||||
for(ck_tile::index_t i_y = 0; i_y < seqlen_q; ++i_y)
|
||||
{
|
||||
ck_tile::index_t x_start = std::max(-y + i_y + 1, static_cast<ck_tile::index_t>(0));
|
||||
ck_tile::index_t x_end = std::min(i_y + x, seqlen_k);
|
||||
if(x_end > x_start)
|
||||
{
|
||||
area += (x_end - x_start);
|
||||
}
|
||||
}
|
||||
return area;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const mask_info& mi)
|
||||
{
|
||||
mi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
BIN
example/ck_tile/42_unified_attention/misc/gamc.png
Normal file
BIN
example/ck_tile/42_unified_attention/misc/gamc.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
84
example/ck_tile/42_unified_attention/rotary.hpp
Normal file
84
example/ck_tile/42_unified_attention/rotary.hpp
Normal file
@@ -0,0 +1,84 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <functional>
|
||||
#include <iterator>
|
||||
#include <optional>
|
||||
#include <random>
|
||||
#include <tuple>
|
||||
|
||||
// keep sync with RotaryEmbeddingEnum
|
||||
enum class rope_enum
|
||||
{
|
||||
none = 0,
|
||||
interleaved = 1,
|
||||
half_rotated = 2,
|
||||
};
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
generate_rotary_cos_sin(ck_tile::index_t seqlen,
|
||||
ck_tile::index_t rotary_dim,
|
||||
std::optional<unsigned> seed = std::nullopt)
|
||||
{
|
||||
// return dummy tensors if we won't apply RoPE at all
|
||||
if(rotary_dim <= 0)
|
||||
{
|
||||
ck_tile::HostTensor<DataType> dummy({1, 1});
|
||||
return std::make_tuple(dummy, dummy);
|
||||
}
|
||||
|
||||
std::mt19937 random_engine(seed.has_value() ? *seed : std::random_device{}());
|
||||
std::uniform_real_distribution<float> generator(0.0f, 1.0f);
|
||||
|
||||
const ck_tile::index_t num_rows = seqlen * 2;
|
||||
const ck_tile::index_t num_cols = rotary_dim / 2;
|
||||
|
||||
using std::begin, std::end;
|
||||
|
||||
ck_tile::HostTensor<float> angle({num_rows, num_cols});
|
||||
std::generate(begin(angle), end(angle), [&] { return generator(random_engine) * 2 * M_PI; });
|
||||
|
||||
ck_tile::HostTensor<DataType> cos({num_rows, num_cols});
|
||||
std::transform(begin(angle), end(angle), begin(cos), [](float origin_value) {
|
||||
return ck_tile::type_convert<DataType>(std::cos(origin_value));
|
||||
});
|
||||
|
||||
ck_tile::HostTensor<DataType> sin({num_rows, num_cols});
|
||||
std::transform(begin(angle), end(angle), begin(sin), [](float origin_value) {
|
||||
return ck_tile::type_convert<DataType>(std::sin(origin_value));
|
||||
});
|
||||
|
||||
return std::make_tuple(cos, sin);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
std::tuple<ck_tile::HostTensor<DataType>, ck_tile::HostTensor<DataType>>
|
||||
slice_rotary_cos_sin(const ck_tile::HostTensor<DataType>& cos,
|
||||
const ck_tile::HostTensor<DataType>& sin,
|
||||
ck_tile::index_t seqlen_offset,
|
||||
ck_tile::index_t seqlen)
|
||||
{
|
||||
assert(cos.get_num_of_dimension() == 2 && sin.get_num_of_dimension() == 2);
|
||||
assert(cos.get_length(0) == sin.get_length(0) && cos.get_length(1) == sin.get_length(1));
|
||||
|
||||
assert(static_cast<std::size_t>(seqlen_offset + seqlen) <= cos.get_length(0));
|
||||
|
||||
const ck_tile::index_t num_rows = seqlen;
|
||||
const ck_tile::index_t num_cols = cos.get_length(1);
|
||||
|
||||
ck_tile::HostTensor<DataType> cos_pt({num_rows, num_cols});
|
||||
cos_pt.ForEach([&](auto& self, auto i) { self(i) = cos(i[0] + seqlen_offset, i[1]); });
|
||||
|
||||
ck_tile::HostTensor<DataType> sin_pt({num_rows, num_cols});
|
||||
sin_pt.ForEach([&](auto& self, auto i) { self(i) = sin(i[0] + seqlen_offset, i[1]); });
|
||||
|
||||
return std::make_tuple(cos_pt, sin_pt);
|
||||
}
|
||||
53
example/ck_tile/42_unified_attention/script/benchmark_fwd.sh
Executable file
53
example/ck_tile/42_unified_attention/script/benchmark_fwd.sh
Executable file
@@ -0,0 +1,53 @@
|
||||
#!/bin/sh
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_unified_attention -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
nhead=$((2048 / $hdim)) # follow fav2 setup
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
#Padding Benchmarks: batch mode (baseline vs low/med/high pad)
|
||||
prec="fp16"
|
||||
base_batch_args="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
|
||||
|
||||
# baseline (no pad)
|
||||
$EXE $base_batch_args
|
||||
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_batch_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
|
||||
# Padding Benchmarks: group mode (baseline vs low/med/high physical pad)
|
||||
seqlens_q="1024,768,512,256"
|
||||
seqlens_k="1024,768,512,256"
|
||||
base_group_args="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=1 -v=$VALID"
|
||||
|
||||
# baseline (no physical pad)
|
||||
$EXE $base_group_args
|
||||
|
||||
# low physical pad
|
||||
$EXE $base_group_args -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
|
||||
|
||||
# medium physical pad
|
||||
$EXE $base_group_args -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
|
||||
|
||||
# high physical pad
|
||||
$EXE $base_group_args -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
|
||||
48
example/ck_tile/42_unified_attention/script/run_full_test.sh
Executable file
48
example/ck_tile/42_unified_attention/script/run_full_test.sh
Executable file
@@ -0,0 +1,48 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# in order to run this script you'd first need to build the tile_example_fmha_fwd and tile_eaxmple_fmha_bwd executables in ../build/bin/
|
||||
#
|
||||
# run the script as "./run_full_test.sh <tag for your test environment> <branch name> <host name> <gpu_arch>
|
||||
# input arguments:
|
||||
# environment tag : a string describing the specifics of your test environment
|
||||
# branch name : name of the branch in git repo (git status | grep -e 'On branch')
|
||||
# host name : $hostname
|
||||
# gpu architecture: e.g., gfx90a, or gfx942, etc.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
#get the command line arguments:
|
||||
export env_type=$1
|
||||
echo 'Environment type: ' $env_type
|
||||
export branch=$2
|
||||
echo 'Branch name: ' $branch
|
||||
export host_name=$3
|
||||
echo 'Host name: ' $host_name
|
||||
export GPU_arch=$4
|
||||
echo 'GPU_arch: ' $GPU_arch
|
||||
|
||||
function print_log_header(){
|
||||
rm -f $1;
|
||||
echo 'On branch ' $3 &> $1;
|
||||
echo 'Node name: ' $4 >> $1;
|
||||
#get GPU_arch and number of compute units from rocminfo
|
||||
echo -n "GPU_arch: " >> $1; rocminfo | grep "Name:" | grep "gfx" >> $1;
|
||||
rocminfo | grep "Compute Unit:" >> $1;
|
||||
hipcc --version | grep -e 'HIP version' >> $1;
|
||||
echo 'Environment type: ' $2 >> $1;
|
||||
/opt/rocm/bin/amdclang++ --version | grep -e 'InstalledDir' >> $1;
|
||||
}
|
||||
|
||||
#run verification tests
|
||||
time example/ck_tile/01_fmha/script/smoke_test_fwd.sh
|
||||
time example/ck_tile/01_fmha/script/smoke_test_bwd.sh
|
||||
|
||||
#run performance benchmarks
|
||||
export fmha_fwd_log="perf_fmha_fwd_$GPU_arch.log"
|
||||
print_log_header $fmha_fwd_log $env_type $branch $host_name
|
||||
time example/ck_tile/01_fmha/script/benchmark_fwd.sh 2>&1 | tee -a $fmha_fwd_log
|
||||
|
||||
export fmha_bwd_log="perf_fmha_bwd_$GPU_arch.log"
|
||||
print_log_header $fmha_bwd_log $env_type $branch $host_name
|
||||
time example/ck_tile/01_fmha/script/benchmark_bwd.sh 2>&1 | tee -a $fmha_bwd_log
|
||||
|
||||
90
example/ck_tile/42_unified_attention/script/smoke_test_bwd.sh
Executable file
90
example/ck_tile/42_unified_attention/script/smoke_test_bwd.sh
Executable file
@@ -0,0 +1,90 @@
|
||||
#!/bin/bash
|
||||
# TODO: run this script from CK root or build directory
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_bwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=${GPU_arch:-""}
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_bwd_fails_$GPU_arch.txt"}
|
||||
rm -f $CURR_FAILS_FILE
|
||||
touch $CURR_FAILS_FILE
|
||||
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_bwd_known_fails_$GPU_arch.txt"}
|
||||
|
||||
COMMON_ARGS='-v=1'
|
||||
|
||||
run_exe() {
|
||||
set +ex
|
||||
$EXE $@
|
||||
local ret=$?
|
||||
if [ $ret -ne 0 ] ; then
|
||||
echo "$EXE_NAME $*" >> $CURR_FAILS_FILE
|
||||
fi
|
||||
set -ex
|
||||
}
|
||||
|
||||
test_h_s_mask() {
|
||||
run_exe -b=1 -h=4 -h_k=2 -s=259 $@
|
||||
run_exe -b=2 -h=2 -s=516 -s_k=253 $@
|
||||
run_exe -b=1 -h=4 -h_k=1 -s=500 -s_k=251 -mask=1 $@
|
||||
run_exe -b=1 -h=2 -s=900 -s_k=258 -mask=2 $@
|
||||
run_exe -b=2 -h=1 -s=987 -s_k=219 -mask=t:128,30 $@
|
||||
run_exe -b=2 -h=3 -h_k=1 -s=244 -s_k=499 -mask=b:4,35 $@
|
||||
}
|
||||
|
||||
set -x
|
||||
# main tests
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for mode in 0 1 ; do
|
||||
for bias in "n" "a" ; do
|
||||
for dbias in 0 ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for deterministic in 0 ; do
|
||||
test_h_s_mask -prec=$prec -d=$hdim -bias=$bias -dbias=$dbias -p_drop=$p_drop -iperm=$perm -operm=$perm -deterministic=$deterministic -v=1 -mode=$mode -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# additional cases
|
||||
for hdim in 40 48 72 96 ; do
|
||||
test_h_s_mask -prec=fp16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=0 -operm=0 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=n -dbias=0 -p_drop=0 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
test_h_s_mask -prec=bf16 -d=$hdim -bias=a -dbias=0 -p_drop=0.2 -iperm=1 -operm=1 -deterministic=0 -v=1 -mode=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
known_fails_count=0
|
||||
if [ -f $KNOWN_FAILS_FILE ] ; then
|
||||
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
|
||||
while IFS= read -r line; do
|
||||
if grep -Fxq "$line" $KNOWN_FAILS_FILE; then
|
||||
echo "Known fail: $line"
|
||||
known_fails_count=$(($known_fails_count + 1))
|
||||
else
|
||||
echo "New fail: $line"
|
||||
new_fails_count=$(($new_fails_count + 1))
|
||||
fi
|
||||
done < $CURR_FAILS_FILE
|
||||
else
|
||||
new_fails_count=$(wc -l < $CURR_FAILS_FILE)
|
||||
echo "No known fails file, all fails ($new_fails_count) are new:"
|
||||
cat $CURR_FAILS_FILE
|
||||
fi
|
||||
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
|
||||
exit $(($new_fails_count != 0))
|
||||
281
example/ck_tile/42_unified_attention/script/smoke_test_fwd.sh
Executable file
281
example/ck_tile/42_unified_attention/script/smoke_test_fwd.sh
Executable file
@@ -0,0 +1,281 @@
|
||||
#!/bin/bash
|
||||
# TODO: run this script from CK root or build directory
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_DIR=$(cd $(dirname "${BASH_SOURCE[0]}") && pwd)
|
||||
EXE_NAME=tile_example_fmha_fwd
|
||||
EXE="$(find . -name $EXE_NAME -type f | head -n 1)"
|
||||
KNAME=1
|
||||
GPU_arch=$GPU_arch
|
||||
if [ -z "$GPU_arch" ] ; then
|
||||
GPU_arch=$(rocminfo | grep -E 'Name:\s+gfx' | head -n1 | awk '{print $2}')
|
||||
fi
|
||||
|
||||
export CK_WARMUP=0
|
||||
export CK_REPEAT=1
|
||||
|
||||
CURR_FAILS_FILE=${CURR_FAILS_FILE:-"fmha_fwd_fails_$GPU_arch.txt"}
|
||||
rm -f $CURR_FAILS_FILE
|
||||
touch $CURR_FAILS_FILE
|
||||
KNOWN_FAILS_FILE=${KNOWN_FAILS_FILE:-"$SCRIPT_DIR/fmha_fwd_known_fails_$GPU_arch.txt"}
|
||||
|
||||
COMMON_ARGS='-v=1 -warmup=0 -repeat=1'
|
||||
# mode=0
|
||||
# export HIP_VISIBLE_DEVICES=4
|
||||
|
||||
TEST_SPLITKV=0
|
||||
TEST_APPENDKV=0
|
||||
# options:
|
||||
# -s: run splitkv tests
|
||||
# -a: run appendkv tests
|
||||
while getopts ":sa" opt; do
|
||||
case "${opt}" in
|
||||
s)
|
||||
TEST_SPLITKV=1
|
||||
;;
|
||||
a)
|
||||
TEST_APPENDKV=1
|
||||
;;
|
||||
*)
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
run_exe() {
|
||||
set +ex
|
||||
$EXE $@
|
||||
local ret=$?
|
||||
if [ $ret -ne 0 ] ; then
|
||||
echo "$EXE_NAME $*" >> $CURR_FAILS_FILE
|
||||
fi
|
||||
set -ex
|
||||
}
|
||||
|
||||
run_fp16_bf16_tests() {
|
||||
local NUM_SPLITS="1"
|
||||
local PAGE_BLOCK_SIZE="0"
|
||||
local CACHE_BATCH_IDX="0"
|
||||
|
||||
if [ $TEST_SPLITKV -eq 1 ] ; then
|
||||
NUM_SPLITS="$NUM_SPLITS 2 3"
|
||||
PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128"
|
||||
CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1"
|
||||
fi
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for mode in 1 0 ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for lse in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for p_drop in 0.0 0.2 ; do
|
||||
for num_splits in $NUM_SPLITS ; do
|
||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# run_exe -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16 -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
run_exe -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
$EXE -prec=fp8 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8bf16_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
$EXE -prec=fp8bf16 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp8fp32_tests() {
|
||||
for perm in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
for b in 1 2 ; do
|
||||
for hdim in 64 128 256 ; do
|
||||
|
||||
$EXE -prec=fp8fp32 -init=0 -b=$b -h=1 -d=128 -s=128 -bias=$bias -iperm=$perm -operm=$perm -vlayout=r -squant=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done
|
||||
}
|
||||
|
||||
run_fp16_appendkv_tests() {
|
||||
for s in $(seq 63 1 65) ; do
|
||||
for s_k in 65 129 ; do
|
||||
for s_knew in 0 64 $s_k ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for ri in 0 1 ; do
|
||||
for rdim in 0 16 32 $hdim ; do
|
||||
for page_block_size in 0 128 ; do
|
||||
for cache_batch_idx in 0 1 ; do
|
||||
|
||||
run_exe -prec=fp16 -b=3 -h=3 -d=$hdim -s=$s -s_k=$s_k -s_knew=$s_knew -rotary_dim=$rdim -rotary_interleaved=$ri -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -iperm=1 -operm=1 -kname=1 $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done
|
||||
}
|
||||
|
||||
run_padding_smoke_tests() {
|
||||
# Padding-only smoke tests for batch/group mode using COMMON_ARGS
|
||||
local prec="fp16"
|
||||
|
||||
# Batch mode: padding via effective lengths (exclude PAD)
|
||||
# Use lse=1 to select a non-trload kernel and avoid overly strict tolerance mismatches
|
||||
local base_batch="-prec=$prec -mode=0 -b=4 -h=16 -h_k=16 -d=128 -s=1024 -bias=n -mask=0 -lse=1 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_batch -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_batch -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_batch -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
|
||||
# Group mode: padding via physical stride along seqlen
|
||||
local seqlens_q="1024,768,512,256"
|
||||
local seqlens_k="1024,768,512,256"
|
||||
local base_group="-prec=$prec -mode=1 -b=4 -h=16 -h_k=16 -d=128 -s=$seqlens_q -s_k=$seqlens_k -bias=n -mask=0 -lse=0 -iperm=0 -operm=0 -vlayout=r -kname=$KNAME $COMMON_ARGS"
|
||||
# low physical pad
|
||||
$EXE $base_group -s_qpad=1152,896,576,320 -s_kpad=1152,896,576,320
|
||||
# medium physical pad
|
||||
$EXE $base_group -s_qpad=1536,1152,768,384 -s_kpad=1536,1152,768,384
|
||||
# high physical pad
|
||||
$EXE $base_group -s_qpad=2048,1536,1024,512 -s_kpad=2048,1536,1024,512
|
||||
}
|
||||
|
||||
run_padding_basic_boundary_tests() {
|
||||
# Basic padding and boundary tests (reference: smoke_test_fwd_pad.sh)
|
||||
local prec
|
||||
local perm
|
||||
|
||||
# Group mode: Q&K padded with per-batch different strides
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=16 -d_v=32 \
|
||||
-s=55 -s_k=256 -s_qpad=64,60 -s_kpad=272,260 \
|
||||
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# slightly larger, uneven padding strides
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=3 -h=2 -h_k=1 -d=64 -d_v=64 \
|
||||
-s=50,60,40 -s_k=128,256,192 -s_qpad=64,64,64 -s_kpad=160,288,224 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# only K padded; Q unpadded
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 \
|
||||
-s=55 -s_k=256 -s_kpad=272,260 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# use cu_seqlen overrides to skip tail PAD
|
||||
for prec in fp16 bf16 ; do
|
||||
for perm in 0 1 ; do
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -h_k=8 -d=128 -s=3 -s_k=3 \
|
||||
-q_eff_lens=1,2,1,2 -kv_eff_lens=1,2,1,2 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=2 -h=2 -h_k=1 -d=32 -d_v=64 -s=64 -s_k=256 \
|
||||
-q_eff_lens=55,60 -kv_eff_lens=200,256 \
|
||||
-bias=n -p_drop=0.0 -lse=0 -iperm=$perm -operm=$perm \
|
||||
-num_splits=1 -page_block_size=0 -cache_batch_idx=0 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
done
|
||||
|
||||
# no padding (equal), mixed Q/KV, all len=1
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=128,128,128,128 -kv_eff_lens=128,128,128,128 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=10,20,30,40 -kv_eff_lens=40,30,20,10 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
$EXE -prec=$prec -mode=0 -b=4 -h=8 -d=64 -s=128 -s_k=128 \
|
||||
-q_eff_lens=1,1,1,1 -kv_eff_lens=1,1,1,1 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
|
||||
# highly variable logical lengths
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=4 -h=4 -d=32 \
|
||||
-s=1,127,3,65 -s_k=1,127,3,65 -s_kpad=128 \
|
||||
-bias=n -p_drop=0.0 -lse=1 -kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
|
||||
# GQA + Alibi + Causal mask (keep vlayout row-major for fp16/bf16
|
||||
for prec in fp16 bf16 ; do
|
||||
$EXE -prec=$prec -mode=1 -b=2 -h=16 -h_k=4 -d=128 \
|
||||
-s=256,129 -s_k=256,129 -s_kpad=256 \
|
||||
-bias=a -mask=t -lse=1 -iperm=0 -operm=0 -vlayout=r \
|
||||
-kname=$KNAME $COMMON_ARGS
|
||||
done
|
||||
}
|
||||
|
||||
set -x
|
||||
|
||||
run_fp16_bf16_tests
|
||||
run_padding_smoke_tests
|
||||
run_padding_basic_boundary_tests
|
||||
run_fp8_tests
|
||||
run_fp8bf16_tests
|
||||
run_fp8fp32_tests
|
||||
|
||||
if [ $TEST_APPENDKV -eq 1 ] ; then
|
||||
run_fp16_appendkv_tests
|
||||
fi
|
||||
|
||||
set +x
|
||||
|
||||
new_fails_count=0
|
||||
known_fails_count=0
|
||||
if [ -f $KNOWN_FAILS_FILE ] ; then
|
||||
echo "Comparing current fails ($CURR_FAILS_FILE) against known fails ($KNOWN_FAILS_FILE):"
|
||||
while IFS= read -r line; do
|
||||
if grep -Fxq "$line" $KNOWN_FAILS_FILE; then
|
||||
echo "Known fail: $line"
|
||||
known_fails_count=$(($known_fails_count + 1))
|
||||
else
|
||||
echo "New fail: $line"
|
||||
new_fails_count=$(($new_fails_count + 1))
|
||||
fi
|
||||
done < $CURR_FAILS_FILE
|
||||
else
|
||||
new_fails_count=$(wc -l < $CURR_FAILS_FILE)
|
||||
echo "No known fails file, all fails ($new_fails_count) are new:"
|
||||
cat $CURR_FAILS_FILE
|
||||
fi
|
||||
echo "New fails count: $new_fails_count; Known fails count: $known_fails_count"
|
||||
exit $(($new_fails_count != 0))
|
||||
218
example/ck_tile/42_unified_attention/unified_attention.cpp
Normal file
218
example/ck_tile/42_unified_attention/unified_attention.cpp
Normal file
@@ -0,0 +1,218 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "unified_attention_impl.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream,
|
||||
const unified_attention_args::data_type_enum& data_type)
|
||||
{
|
||||
switch(data_type)
|
||||
{
|
||||
case unified_attention_args::data_type_enum::fp16: return stream << "fp16";
|
||||
case unified_attention_args::data_type_enum::bf16: return stream << "bf16";
|
||||
default: return stream << "unknown";
|
||||
}
|
||||
}
|
||||
|
||||
// Helper macro to reduce dispatch boilerplate.
|
||||
// Dispatches based on DataType, IsMasking, HeadSize, BlockM, NumQPerKV.
|
||||
#define DISPATCH_UNIFIED_ATTENTION(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
// Dispatch macros for three tile tiers (default block_size).
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_tiny_kernel_traits<DType, IsMask, HSize, BM, NQPKV>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
// block_size=32 dispatch macros (6th template arg = 32).
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_small_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
#define DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(DType, IsMask, HSize, BM, NQPKV) \
|
||||
{ \
|
||||
using kernel_traits = unified_attention_decode_bs32_kernel_traits<DType, IsMask, HSize, BM, NQPKV, 32>; \
|
||||
return unified_attention_kernel_dispatch_decode<kernel_traits>(args, config); \
|
||||
}
|
||||
|
||||
enum class tile_tier { large, medium, small, tiny };
|
||||
|
||||
static tile_tier select_tile_tier(const unified_attention_args& args)
|
||||
{
|
||||
const index_t avg_q = args.num_seqs > 0 ? args.num_tokens / args.num_seqs : args.num_tokens;
|
||||
const index_t kBlockQ_tiny = 16 / args.num_queries_per_kv; // kBlockQ for 1-warp 16x16 kernel
|
||||
|
||||
if(avg_q <= kBlockQ_tiny)
|
||||
return tile_tier::tiny; // pure decode: 1 warp, 16x16 MFMA, kBlockM=16
|
||||
|
||||
const index_t kBlockQ_small = 64 / args.num_queries_per_kv; // kBlockQ for 2-warp kernel
|
||||
if(avg_q <= kBlockQ_small)
|
||||
return tile_tier::small; // short decode: 2 warps, kBlockM=64
|
||||
|
||||
// 4-warp serial pipeline outperforms 8-warp interleaved on all prefill shapes
|
||||
// (verified by exhaustive sweep over 363 shapes from production trace).
|
||||
return tile_tier::medium; // all prefill: 4 warps, kBlockM=128
|
||||
}
|
||||
|
||||
std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
const bool is_mask = (args.mask_type != static_cast<int>(mask_enum::no_mask));
|
||||
const auto tier = select_tile_tier(args);
|
||||
|
||||
// d128, MHA (num_queries_per_kv == 1)
|
||||
if(args.hdim == 128 && args.num_queries_per_kv == 1)
|
||||
{
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 128, 256, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 128, 256, 1)
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 128, 256, 1)
|
||||
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 128, 256, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// d64, GQA-8 (num_queries_per_kv == 8)
|
||||
if(args.hdim == 64 && args.num_queries_per_kv == 8)
|
||||
{
|
||||
const bool use_bs32 = (args.page_blk_size < 64);
|
||||
|
||||
if(tier == tile_tier::tiny)
|
||||
{
|
||||
if(use_bs32) {
|
||||
// bs32 narrow: 2 warps, 16x16 MFMA, kBlockM=32, kBlockQ=4.
|
||||
// Avoids 1-warp race condition; 2x less waste than small tier.
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::fp16, true, 64, 32, 8)
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, false, 64, 32, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW(unified_attention_args::data_type_enum::bf16, true, 64, 32, 8)
|
||||
}
|
||||
} else {
|
||||
// bs64 tiny: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2.
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(unified_attention_args::data_type_enum::fp16, false, 64, 16, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(unified_attention_args::data_type_enum::fp16, true, 64, 16, 8)
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(unified_attention_args::data_type_enum::bf16, false, 64, 16, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_TINY(unified_attention_args::data_type_enum::bf16, true, 64, 16, 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(tier == tile_tier::small)
|
||||
{
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(use_bs32) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::fp16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::fp16, true, 64, 64, 8)
|
||||
}
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(use_bs32) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::bf16, false, 64, 64, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL(unified_attention_args::data_type_enum::bf16, true, 64, 64, 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(tier == tile_tier::medium)
|
||||
{
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(use_bs32) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::fp16, true, 64, 128, 8)
|
||||
}
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(use_bs32) {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
|
||||
} else {
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, false, 64, 128, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM(unified_attention_args::data_type_enum::bf16, true, 64, 128, 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Large prefill: 8 warps, kBlockM=256 (kBlockQ=32)
|
||||
// No bs32 variant -- NumIssues < 1 for 8-warp tier with block_size=32.
|
||||
if(args.data_type == unified_attention_args::data_type_enum::fp16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, false, 64, 256, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::fp16, true, 64, 256, 8)
|
||||
}
|
||||
else if(args.data_type == unified_attention_args::data_type_enum::bf16)
|
||||
{
|
||||
if(!is_mask) DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, false, 64, 256, 8)
|
||||
else DISPATCH_UNIFIED_ATTENTION(unified_attention_args::data_type_enum::bf16, true, 64, 256, 8)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cerr << "unified_attention: no matching kernel instance for hdim=" << args.hdim
|
||||
<< " num_queries_per_kv=" << args.num_queries_per_kv
|
||||
<< " data_type=" << args.data_type << " mask_type=" << args.mask_type << std::endl;
|
||||
return std::make_pair(false, -1.f);
|
||||
}
|
||||
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_BS32_NARROW
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL_BS32
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM_BS32
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_TINY
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_SMALL
|
||||
#undef DISPATCH_UNIFIED_ATTENTION_DECODE_MEDIUM
|
||||
#undef DISPATCH_UNIFIED_ATTENTION
|
||||
|
||||
} // namespace ck_tile
|
||||
87
example/ck_tile/42_unified_attention/unified_attention.hpp
Normal file
87
example/ck_tile/42_unified_attention/unified_attention.hpp
Normal file
@@ -0,0 +1,87 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/ops/unified_attention.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct unified_attention_args
|
||||
{
|
||||
enum class data_type_enum
|
||||
{
|
||||
fp16,
|
||||
bf16
|
||||
};
|
||||
|
||||
data_type_enum data_type;
|
||||
// bool is_varlen;
|
||||
index_t mask_type; // should be 0 for no mask; or 2 for causal mask (window_size_left < 0 and
|
||||
// window_size_right == 0).
|
||||
|
||||
index_t num_tokens; // total number of tokens in query
|
||||
index_t num_blks;
|
||||
index_t num_head_q;
|
||||
index_t num_queries_per_kv;
|
||||
index_t page_blk_size;
|
||||
// index_t BLOCK_SIZE;
|
||||
|
||||
index_t hdim;
|
||||
// TODO window
|
||||
float scale_s;
|
||||
float scale;
|
||||
float scale_k;
|
||||
float scale_v;
|
||||
float scale_out;
|
||||
|
||||
const void* q_ptr;
|
||||
index_t query_stride_0;
|
||||
index_t query_stride_1;
|
||||
|
||||
const void* k_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
|
||||
index_t stride_k_cache_0;
|
||||
index_t stride_k_cache_1;
|
||||
index_t stride_k_cache_2;
|
||||
index_t stride_k_cache_3;
|
||||
|
||||
const void* v_ptr; // [num_blks, blk_size, num_kv_heads, head_size]
|
||||
index_t stride_v_cache_0;
|
||||
index_t stride_v_cache_1;
|
||||
index_t stride_v_cache_2;
|
||||
index_t stride_v_cache_3;
|
||||
|
||||
void* o_ptr;
|
||||
index_t output_stride_0;
|
||||
index_t output_stride_1;
|
||||
|
||||
const int32_t* block_tables_ptr;
|
||||
index_t block_table_stride;
|
||||
const int32_t* seq_lens_ptr; // seq len in each batch
|
||||
const int32_t* query_start_len_ptr; // [num_seqs+1]
|
||||
|
||||
index_t num_seqs; // number of batches for q
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream,
|
||||
const unified_attention_args::data_type_enum& data_type);
|
||||
|
||||
// return value:
|
||||
// first = whether the kernel was launched (true = launched, false = skipped)
|
||||
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
||||
std::pair<bool, float> unified_attention(const unified_attention_args& args,
|
||||
const stream_config& config);
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
struct UnifiedAttentionMasks
|
||||
{
|
||||
using NoMask = ck_tile::GenericAttentionMask<false>;
|
||||
using GenericMask = ck_tile::GenericAttentionMask<true, true>;
|
||||
using CausalMask = ck_tile::GenericAttentionMask<true, false>;
|
||||
};
|
||||
445
example/ck_tile/42_unified_attention/unified_attention_impl.hpp
Normal file
445
example/ck_tile/42_unified_attention/unified_attention_impl.hpp
Normal file
@@ -0,0 +1,445 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "ck_tile/core/numeric/bfloat16.hpp"
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/container/sequence.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
|
||||
#include "ck_tile/ops/unified_attention/block/block_masking.hpp"
|
||||
#include "ck_tile/ops/unified_attention/kernel/unified_attention_kernel.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_shape.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/tile_unified_attention_traits.hpp"
|
||||
|
||||
#include "unified_attention.hpp"
|
||||
#include "mask.hpp"
|
||||
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH(kernel_traits) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch<kernel_traits>( \
|
||||
const unified_attention_args& args, const stream_config& config) \
|
||||
{ \
|
||||
return std::make_pair( \
|
||||
true, unified_attention_kernel_launch<kernel_traits::kernel>(args, config)); \
|
||||
}
|
||||
|
||||
#define INST_UNIFIED_ATTENTION_DISPATCH_DECODE(kernel_traits) \
|
||||
template <> \
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch_decode<kernel_traits>( \
|
||||
const unified_attention_args& args, const stream_config& config) \
|
||||
{ \
|
||||
return std::make_pair( \
|
||||
true, unified_attention_kernel_launch<kernel_traits::kernel, true>(args, config)); \
|
||||
}
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <unified_attention_args::data_type_enum DataType>
|
||||
struct unified_attention_problem_traits;
|
||||
|
||||
template <>
|
||||
struct unified_attention_problem_traits<unified_attention_args::data_type_enum::fp16>
|
||||
{
|
||||
using qkvp_dtype = ck_tile::half_t;
|
||||
using acc_dtype = float;
|
||||
using o_dtype = ck_tile::half_t;
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct unified_attention_problem_traits<unified_attention_args::data_type_enum::bf16>
|
||||
{
|
||||
using qkvp_dtype = ck_tile::bf16_t;
|
||||
using acc_dtype = float;
|
||||
using o_dtype = ck_tile::bf16_t;
|
||||
using lse_dtype = float;
|
||||
};
|
||||
|
||||
// Parameterized kernel traits: DataType, IsMasking, HeadSize, BlockM, NumQueriesPerKV, BlockSize
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 128,
|
||||
index_t BlockM_ = 256,
|
||||
index_t NumQPerKV_ = 1,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
struct unified_attention_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
// kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// 8 warps for warp specialization; kBlockM must be 8 * 32 = 256
|
||||
using unified_attention_block_warps = sequence<8, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true // IsVLayoutRowMajor
|
||||
>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, // kPadSeqLenQ_
|
||||
false, // kPadHeadDimQ
|
||||
-1 // kBlockPerCu
|
||||
>;
|
||||
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, /*IsLocal=*/false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, // kPadM
|
||||
true, // kPadM
|
||||
true // UseRawStore
|
||||
>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// Decode-tuned traits: 4 warps (1 warp group), kBlockM=128, serial pipeline.
|
||||
// Uses the single-warp-group path in UnifiedAttentionPipeline.
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 128,
|
||||
index_t BlockM_ = 128,
|
||||
index_t NumQPerKV_ = 1,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
struct unified_attention_decode_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
// kBlockM kBlockQ BLOCK_SIZE HEAD_SIZE
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// 4 warps -> kBlockSize = 256 threads -> NumWarpGroups = 1
|
||||
using unified_attention_block_warps = sequence<4, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline = UnifiedAttentionPipeline<unified_attention_pipeline_problem>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// Small decode traits: 2 warps, kBlockM=64, decode policy (NumWarpPerGroup=2).
|
||||
// Uses 1D warp layout (sequence<2,1,1>) so no softmax reduction changes needed.
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 64,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
struct unified_attention_decode_small_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<32, 32, 16>;
|
||||
// 2 warps along M: kBlockM=2*32=64, kBlockSize=128, NumWarpGroups=1
|
||||
using unified_attention_block_warps = sequence<2, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
UnifiedAttentionPipelineDecodePolicy>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// Tiny decode traits: 1 warp, 16x16 MFMA, kBlockM=16, kBlockQ=2 for GQA-8.
|
||||
// Matches Triton's BLOCK_M=16 / BLOCK_Q=2 decode configuration.
|
||||
// Uses block_tile_reduce_sync instead of permlane32_swap for 16x16 MFMA.
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 16,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = (HeadSize_ <= 64) ? 64 : 32>
|
||||
struct unified_attention_decode_tiny_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<16, 16, 32>;
|
||||
// 1 warp: kBlockM=1*16=16, kBlockSize=64, NumWarpGroups=1
|
||||
using unified_attention_block_warps = sequence<1, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
UnifiedAttentionPipelineTinyDecodePolicy>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
// bs32 decode traits: 2 warps, 16x16 MFMA, kBlockM=32, kBlockQ=4 for GQA-8.
|
||||
// Used for block_size=32 decode: avoids the 1-warp pipeline race condition
|
||||
// and reduces query waste from 87.5% (small tier kBlockQ=8) to 75% (kBlockQ=4).
|
||||
template <unified_attention_args::data_type_enum DataType,
|
||||
bool IsMasking,
|
||||
index_t HeadSize_ = 64,
|
||||
index_t BlockM_ = 32,
|
||||
index_t NumQPerKV_ = 8,
|
||||
index_t BlockSize_ = 32>
|
||||
struct unified_attention_decode_bs32_kernel_traits
|
||||
{
|
||||
static constexpr auto date_type = DataType;
|
||||
static constexpr bool is_masking = IsMasking;
|
||||
|
||||
static constexpr index_t kBlockM = BlockM_;
|
||||
static constexpr index_t HEAD_SIZE = HeadSize_;
|
||||
static constexpr index_t BLOCK_SIZE = BlockSize_;
|
||||
|
||||
static constexpr index_t num_queries_per_kv = NumQPerKV_;
|
||||
static constexpr index_t kBlockQ = kBlockM / num_queries_per_kv;
|
||||
|
||||
using unified_attention_block_tile = sequence<kBlockM, kBlockQ, BLOCK_SIZE, HEAD_SIZE>;
|
||||
using unified_attention_warp_gemm_shape = sequence<16, 16, 32>;
|
||||
using unified_attention_block_warps = sequence<2, 1, 1>;
|
||||
|
||||
using unified_attention_shape = TileUnifiedAttentionShape<unified_attention_block_tile,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
unified_attention_block_warps,
|
||||
unified_attention_warp_gemm_shape,
|
||||
true>;
|
||||
|
||||
using unified_attention_traits = TileUnifiedAttentionTraits<true, false, -1>;
|
||||
using unified_attention_mask = GenericAttentionMask<IsMasking, false>;
|
||||
|
||||
using unified_attention_pipeline_problem = UnifiedAttentionPipelineProblem<
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::lse_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::qkvp_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
unified_attention_shape,
|
||||
unified_attention_mask,
|
||||
unified_attention_traits>;
|
||||
|
||||
using unified_attention_pipeline =
|
||||
UnifiedAttentionPipeline<unified_attention_pipeline_problem,
|
||||
UnifiedAttentionPipelineDecodePolicy>;
|
||||
|
||||
using epilogue = Default2DEpilogue<
|
||||
Default2DEpilogueProblem<typename unified_attention_problem_traits<date_type>::acc_dtype,
|
||||
typename unified_attention_problem_traits<date_type>::o_dtype,
|
||||
true, true, true>>;
|
||||
|
||||
using kernel = UnifiedAttentionKernel<unified_attention_pipeline, epilogue>;
|
||||
};
|
||||
|
||||
template <typename Kernel, bool UseDecodeGrid = false>
|
||||
float unified_attention_kernel_launch(const unified_attention_args& args,
|
||||
const stream_config& config)
|
||||
{
|
||||
constexpr index_t kBlockQ = Kernel::kBlockQ;
|
||||
index_t total_num_q_blocks = args.num_tokens / kBlockQ + args.num_seqs;
|
||||
auto kargs = Kernel::MakeKargs(args.q_ptr,
|
||||
args.k_ptr,
|
||||
args.v_ptr,
|
||||
args.o_ptr,
|
||||
args.num_blks,
|
||||
args.num_head_q,
|
||||
args.num_queries_per_kv,
|
||||
args.scale_s,
|
||||
args.scale,
|
||||
args.scale_k,
|
||||
args.scale_v,
|
||||
args.scale_out,
|
||||
args.page_blk_size,
|
||||
total_num_q_blocks,
|
||||
args.query_stride_0,
|
||||
args.query_stride_1,
|
||||
args.stride_k_cache_0,
|
||||
args.stride_k_cache_1,
|
||||
args.stride_k_cache_2,
|
||||
args.stride_k_cache_3,
|
||||
args.stride_v_cache_0,
|
||||
args.stride_v_cache_1,
|
||||
args.stride_v_cache_2,
|
||||
args.stride_v_cache_3,
|
||||
args.output_stride_0,
|
||||
args.output_stride_1,
|
||||
args.block_tables_ptr,
|
||||
args.block_table_stride,
|
||||
args.seq_lens_ptr,
|
||||
args.query_start_len_ptr,
|
||||
args.num_seqs);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(UseDecodeGrid)
|
||||
{
|
||||
grids = Kernel::GridSizeDecode(args.num_head_q / args.num_queries_per_kv, args.num_seqs);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize2D(args.num_head_q / args.num_queries_per_kv, total_num_q_blocks);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr index_t kBlockPerCu = Kernel::kBlockPerCu;
|
||||
|
||||
return launch_kernel(config, make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
|
||||
// return value:
|
||||
// first = whether the kernel was launched (true = launched, false = skipped)
|
||||
// second = elapsed time (ms) of the kernel launch, valid only if first == true
|
||||
template <typename KernelTraits>
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch(const unified_attention_args& args,
|
||||
const stream_config& config);
|
||||
|
||||
template <typename KernelTraits>
|
||||
std::pair<bool, float> unified_attention_kernel_dispatch_decode(const unified_attention_args& args,
|
||||
const stream_config& config);
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user