Add codegen instances

The following examples have been tested for 04_codegen:

./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 256 256
./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 64 64
./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 32 32
./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 128 128
./bin/codegen_basic_flash_attention_fwd 1 1 64 2048 2048 128 128
./bin/codegen_basic_flash_attention_fwd 1 1 64 512 512 128 128
This commit is contained in:
Clement Lin
2025-04-23 11:48:06 +08:00
parent 068d9fdbf7
commit 35de33c57b
8 changed files with 116 additions and 179 deletions

View File

@@ -4,7 +4,6 @@
#pragma once
#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp"
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
namespace ck_tile {

View File

@@ -36,25 +36,21 @@ int main(int argc, char* argv[])
ck_tile::index_t N1 = 128; // HeadDim
ck_tile::index_t verification = 0;
ck_tile::index_t init_method = 1;
[[maybe_unused]] ck_tile::index_t time_kernel = 0;
if(argc == 4)
if(argc == 3)
{
init_method = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
verification = std::stoi(argv[3]);
verification = std::stoi(argv[2]);
}
if(argc == 9)
else if(argc == 8)
{
init_method = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
verification = std::stoi(argv[3]);
Batch = std::stoi(argv[4]);
M0 = std::stoi(argv[5]);
N0 = std::stoi(argv[6]);
K0 = std::stoi(argv[7]);
N1 = std::stoi(argv[8]);
verification = std::stoi(argv[2]);
Batch = std::stoi(argv[3]);
M0 = std::stoi(argv[4]);
N0 = std::stoi(argv[5]);
K0 = std::stoi(argv[6]);
N1 = std::stoi(argv[7]);
}
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};

View File

@@ -38,6 +38,8 @@ struct BlockGemmARegBSmemCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
@@ -46,7 +48,7 @@ struct BlockGemmARegBSmemCRegV1
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
@@ -180,6 +182,8 @@ struct BlockGemmARegBSmemCRegV1
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
"wrong!");
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
@@ -188,7 +192,7 @@ struct BlockGemmARegBSmemCRegV1
KPerBlock == BlockGemmShape::kK,
"wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;

View File

@@ -10,10 +10,25 @@ namespace ck_tile {
struct BlockGemmARegBSmemCRegV1DefaultPolicy
{
template <typename Problem>
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
if constexpr (kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
else if constexpr (kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
}
else if constexpr (kM0 == 128)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
else
{
static_assert(false, "Unsupported configuration for warp execution.");
}
}
};

View File

@@ -10,10 +10,25 @@ namespace ck_tile {
struct BlockGemmARegBSmemCRegV1K8Policy
{
template <typename Problem>
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
if constexpr (kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
}
else if constexpr (kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
}
else if constexpr (kM0 == 128)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
}
else
{
static_assert(false, "Unsupported configuration for warp execution.");
}
}
};

View File

@@ -62,11 +62,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
constexpr auto blockgemm = GetBlockGemm<Problem>();
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!");
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = AKDim;
constexpr auto config =
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem, kMPerBlock>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;

View File

@@ -36,25 +36,21 @@ int main(int argc, char* argv[])
ck_tile::index_t N1 = 128; // HeadDim
ck_tile::index_t verification = 0;
ck_tile::index_t init_method = 1;
[[maybe_unused]] ck_tile::index_t time_kernel = 0;
if(argc == 4)
if(argc == 3)
{
init_method = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
verification = std::stoi(argv[3]);
verification = std::stoi(argv[2]);
}
if(argc == 9)
else if(argc == 8)
{
init_method = std::stoi(argv[1]);
time_kernel = std::stoi(argv[2]);
verification = std::stoi(argv[3]);
Batch = std::stoi(argv[4]);
M0 = std::stoi(argv[5]);
N0 = std::stoi(argv[6]);
K0 = std::stoi(argv[7]);
N1 = std::stoi(argv[8]);
verification = std::stoi(argv[2]);
Batch = std::stoi(argv[3]);
M0 = std::stoi(argv[4]);
N0 = std::stoi(argv[5]);
K0 = std::stoi(argv[6]);
N1 = std::stoi(argv[7]);
}
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};

View File

@@ -11,16 +11,8 @@ import itertools
import copy
from dataclasses import dataclass
# def get_if_str(idx, total, last_else=True):
# if idx == 0:
# return 'if'
# elif idx < total - 1:
# return 'else if'
# else:
# return 'else' if last_else else 'else if'
def get_if_str(size_, total, last_else=True):
if size_ == "small":
if size_ == "head_dim_256_seq_4096":
return 'if'
else:
return 'else if'
@@ -39,13 +31,13 @@ template <typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
index_t kBlockSize_,
index_t kHeadDim_,
index_t kM0PerBlock_,
index_t kN0PerBlock_,
index_t kK0PerBlock_,
index_t kN1PerBlock_,
index_t kK1PerBlock_>
index_t kBlockSize_ = 256,
index_t kHeadDim_ = 128,
index_t kM0PerBlock_ = 128,
index_t kN0PerBlock_ = 128,
index_t kK0PerBlock_ = 64,
index_t kN1PerBlock_ = 128,
index_t kK1PerBlock_ = 64>
struct flash_attention_fwd_traits_
{
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
@@ -62,7 +54,7 @@ struct flash_attention_fwd_traits_
static constexpr index_t kK1PerBlock = kK1PerBlock_;
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
};
@@ -70,13 +62,13 @@ template <typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
ck_tile::index_t kBlockSize,
ck_tile::index_t kHeadDim,
ck_tile::index_t kM0PerBlock,
ck_tile::index_t kN0PerBlock,
ck_tile::index_t kK0PerBlock,
ck_tile::index_t kN1PerBlock,
ck_tile::index_t kK1PerBlock>
ck_tile::index_t kBlockSize = 256,
ck_tile::index_t kHeadDim = 128,
ck_tile::index_t kM0PerBlock = 128,
ck_tile::index_t kN0PerBlock = 128,
ck_tile::index_t kK0PerBlock = 64,
ck_tile::index_t kN1PerBlock = 128,
ck_tile::index_t kK1PerBlock = 64>
using traits_ = flash_attention_fwd_traits_<SaccDataType,
SMPLComputeDataType,
PDataType,
@@ -90,78 +82,6 @@ using traits_ = flash_attention_fwd_traits_<SaccDataType,
kK1PerBlock>;
"""
# API_COMMON_HEADER = """
# // SPDX-License-Identifier: MIT
# // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
# #include <ck_tile/core.hpp>
# #include "flash_attention_fwd.hpp"
# #include <iostream>
# #pragma once
# using S = ck_tile::stream_config;
# using A = FlashAttnArgs;
# {F_traits_define}
# template <typename QDataType,
# typename KDataType,
# typename VDataType,
# typename ODataType,
# typename Traits_>
# float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
# const ck_tile::stream_config& stream_config) {{
# using SaccDataType = typename Traits_::SaccDataType;
# using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
# using PDataType = typename Traits_::PDataType;
# using OaccDataType = typename Traits_::OaccDataType;
# index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
# if(stream_config.log_level_ > 0)
# std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush;
# return ck_tile::launch_kernel(stream_config,
# ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
# ck_tile::FlashAttentionFwd<QDataType,
# KDataType,
# VDataType,
# SaccDataType,
# SMPLComputeDataType,
# PDataType,
# OaccDataType,
# ODataType,
# Traits_::kBlockSize,
# Traits_::kHeadDim,
# Traits_::kM0PerBlock,
# Traits_::kN0PerBlock,
# Traits_::kK0PerBlock,
# Traits_::kN1PerBlock,
# Traits_::kK1PerBlock>{{}},
# kGridSize,
# Traits_::kBlockSize,
# 0,
# a.q_ptr,
# a.k_ptr,
# a.v_ptr,
# a.o_ptr,
# a.M0,
# a.N0,
# a.K0,
# a.N1,
# a.Batch,
# a.strideQ, // StrideQ
# a.strideK, // StrideK
# a.strideV, // StrideV
# a.strideO, // StrideO
# a.batchStrideQ, // BatchStrideQ
# a.batchStrideK, // BatchStrideK
# a.batchStrideV, // BatchStrideV
# a.batchStrideO)); // BatchStrideO
# }}
# """
API_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
@@ -204,14 +124,6 @@ template float flash_attention_fwd<ck_tile::half_t, ck_tile::half_t, ck_tile::ha
}}
"""
# API_PER_DTYPE = """ {F_if}(std::is_same_v<QDataType, {F_q_type}> && std::is_same_v<KDataType, {F_k_type}> && std::is_same_v<VDataType, {F_v_type}> && std::is_same_v<ODataType, {F_o_type}>) {{
# {F_per_size_case}
# }}
# """
# API_PER_SIZE_CASE = """ {F_if} {F_SIZE_COND} {{
# {F_inner_dispatch}
# }}
# """
API_INNER_CASE = """ {F_if} {F_VEC_COND}
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(a, stream_config);
"""
@@ -320,13 +232,13 @@ template <typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
index_t kBlockSize_,
index_t kHeadDim_,
index_t kM0PerBlock_,
index_t kN0PerBlock_,
index_t kK0PerBlock_,
index_t kN1PerBlock_,
index_t kK1PerBlock_>
index_t kBlockSize_ = 256,
index_t kHeadDim_ = 128,
index_t kM0PerBlock_ = 128,
index_t kN0PerBlock_ = 128,
index_t kK0PerBlock_ = 64,
index_t kN1PerBlock_ = 128,
index_t kK1PerBlock_ = 64>
struct flash_attention_fwd_traits_
{{
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
@@ -456,36 +368,28 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
len_in_size = sum(len(b.instance_list) for b in blob_per_size)
size_cond = ""
if size_ == "small":
size_cond = "(a.M0 < 2048 && a.N0 < 2048)"
elif size_ == "medium":
size_cond = "(a.M0 >= 2048 && a.N0 >= 2048 && a.M0 < 4096 && a.N0 < 4096)"
else: # large
size_cond = "(a.M0 >= 4096 || a.N0 >= 4096)"
if size_ == "head_dim_256_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 256 && a.N1 == 256)"
elif size_ == "head_dim_128_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_64_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 64 && a.N1 == 64)"
elif size_ == "head_dim_32_seq_4096":
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 32 && a.N1 == 32)"
elif size_ == "head_dim_128_seq_2048":
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
elif size_ == "head_dim_128_seq_512":
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 128 && a.N1 == 128)"
else:
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
inner_str += self.API_INNER_CASE.format(
# F_if=get_if_str(idx_in_size, len_in_size, False),
F_if=get_if_str(size_, len_in_size, False),
F_VEC_COND=size_cond,
F_trait_name=ins.trait_name
)
# size_str += self.API_PER_SIZE_CASE.format(
# F_if=get_if_str(i_size, len(blob_per_t)),
# F_SIZE_COND=size_cond,
# F_inner_dispatch=inner_str
# )
size_str += inner_str
# q_type, k_type, v_type, o_type = dtype_.split(',')
# d_str += self.API_PER_DTYPE.format(
# F_if=get_if_str(i_d, len(t_dtype_dict)),
# F_q_type=DATA_TYPE_MAP[q_type],
# F_k_type=DATA_TYPE_MAP[k_type],
# F_v_type=DATA_TYPE_MAP[v_type],
# F_o_type=DATA_TYPE_MAP[o_type],
# F_per_size_case=size_str
# )
d_str += size_str
api_base = self.API_BASE.format(
@@ -500,18 +404,24 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
# Define kernel configurations for different size categories
trait_dict = {
"small": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 128, 128, 128, 32, 128, 32),
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 64, 32, 64, 32),
"head_dim_256_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
],
"medium": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 128, 128, 128, 32, 128, 32),
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 256, 128, 32, 128, 32),
"head_dim_128_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
],
"head_dim_64_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
],
"head_dim_32_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 32, 32, 32, 32, 32, 32),
],
"head_dim_128_seq_2048": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
],
"head_dim_128_seq_512": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 128, 128, 128),
],
"large": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 512, 128, 256, 256, 32, 256, 32),
]
}
# Toy example only supports fp16
@@ -561,7 +471,7 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
blobs = self.get_blobs(args)
for b in blobs:
(w_p / (b.name + ".cpp")).write_text(b.content)
def list_blobs(args):
api_list = args.api.split(',')
for api in api_list: