mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Add codegen instances
The following examples have been tested for 04_codegen: ./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 256 256 ./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 64 64 ./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 32 32 ./bin/codegen_basic_flash_attention_fwd 1 1 64 4096 4096 128 128 ./bin/codegen_basic_flash_attention_fwd 1 1 64 2048 2048 128 128 ./bin/codegen_basic_flash_attention_fwd 1 1 64 512 512 128 128
This commit is contained in:
@@ -4,7 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp"
|
||||
#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
@@ -36,25 +36,21 @@ int main(int argc, char* argv[])
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
[[maybe_unused]] ck_tile::index_t time_kernel = 0;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 3)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
verification = std::stoi(argv[3]);
|
||||
verification = std::stoi(argv[2]);
|
||||
}
|
||||
|
||||
if(argc == 9)
|
||||
else if(argc == 8)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
verification = std::stoi(argv[3]);
|
||||
Batch = std::stoi(argv[4]);
|
||||
M0 = std::stoi(argv[5]);
|
||||
N0 = std::stoi(argv[6]);
|
||||
K0 = std::stoi(argv[7]);
|
||||
N1 = std::stoi(argv[8]);
|
||||
verification = std::stoi(argv[2]);
|
||||
Batch = std::stoi(argv[3]);
|
||||
M0 = std::stoi(argv[4]);
|
||||
N0 = std::stoi(argv[5]);
|
||||
K0 = std::stoi(argv[6]);
|
||||
N1 = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};
|
||||
|
||||
@@ -38,6 +38,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
@@ -46,7 +48,7 @@ struct BlockGemmARegBSmemCRegV1
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
@@ -180,6 +182,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
|
||||
constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
|
||||
@@ -188,7 +192,7 @@ struct BlockGemmARegBSmemCRegV1
|
||||
KPerBlock == BlockGemmShape::kK,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, MPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
|
||||
@@ -10,10 +10,25 @@ namespace ck_tile {
|
||||
|
||||
struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
if constexpr (kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -10,10 +10,25 @@ namespace ck_tile {
|
||||
|
||||
struct BlockGemmARegBSmemCRegV1K8Policy
|
||||
{
|
||||
template <typename Problem>
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
if constexpr (kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported configuration for warp execution.");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -62,11 +62,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
constexpr auto blockgemm = GetBlockGemm<Problem>();
|
||||
using BlockGemm = remove_cvref_t<decltype(blockgemm)>;
|
||||
|
||||
static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!");
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = AKDim;
|
||||
|
||||
constexpr auto config =
|
||||
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp<Problem, kMPerBlock>();
|
||||
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
|
||||
@@ -36,25 +36,21 @@ int main(int argc, char* argv[])
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
[[maybe_unused]] ck_tile::index_t time_kernel = 0;
|
||||
|
||||
if(argc == 4)
|
||||
if(argc == 3)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
verification = std::stoi(argv[3]);
|
||||
verification = std::stoi(argv[2]);
|
||||
}
|
||||
|
||||
if(argc == 9)
|
||||
else if(argc == 8)
|
||||
{
|
||||
init_method = std::stoi(argv[1]);
|
||||
time_kernel = std::stoi(argv[2]);
|
||||
verification = std::stoi(argv[3]);
|
||||
Batch = std::stoi(argv[4]);
|
||||
M0 = std::stoi(argv[5]);
|
||||
N0 = std::stoi(argv[6]);
|
||||
K0 = std::stoi(argv[7]);
|
||||
N1 = std::stoi(argv[8]);
|
||||
verification = std::stoi(argv[2]);
|
||||
Batch = std::stoi(argv[3]);
|
||||
M0 = std::stoi(argv[4]);
|
||||
N0 = std::stoi(argv[5]);
|
||||
K0 = std::stoi(argv[6]);
|
||||
N1 = std::stoi(argv[7]);
|
||||
}
|
||||
|
||||
std::array<ck_tile::index_t, 3> q_lengths{Batch, M0, K0};
|
||||
|
||||
@@ -11,16 +11,8 @@ import itertools
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
# def get_if_str(idx, total, last_else=True):
|
||||
# if idx == 0:
|
||||
# return 'if'
|
||||
# elif idx < total - 1:
|
||||
# return 'else if'
|
||||
# else:
|
||||
# return 'else' if last_else else 'else if'
|
||||
|
||||
def get_if_str(size_, total, last_else=True):
|
||||
if size_ == "small":
|
||||
if size_ == "head_dim_256_seq_4096":
|
||||
return 'if'
|
||||
else:
|
||||
return 'else if'
|
||||
@@ -39,13 +31,13 @@ template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kHeadDim_,
|
||||
index_t kM0PerBlock_,
|
||||
index_t kN0PerBlock_,
|
||||
index_t kK0PerBlock_,
|
||||
index_t kN1PerBlock_,
|
||||
index_t kK1PerBlock_>
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kHeadDim_ = 128,
|
||||
index_t kM0PerBlock_ = 128,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
@@ -62,7 +54,7 @@ struct flash_attention_fwd_traits_
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
};
|
||||
|
||||
@@ -70,13 +62,13 @@ template <typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
ck_tile::index_t kBlockSize,
|
||||
ck_tile::index_t kHeadDim,
|
||||
ck_tile::index_t kM0PerBlock,
|
||||
ck_tile::index_t kN0PerBlock,
|
||||
ck_tile::index_t kK0PerBlock,
|
||||
ck_tile::index_t kN1PerBlock,
|
||||
ck_tile::index_t kK1PerBlock>
|
||||
ck_tile::index_t kBlockSize = 256,
|
||||
ck_tile::index_t kHeadDim = 128,
|
||||
ck_tile::index_t kM0PerBlock = 128,
|
||||
ck_tile::index_t kN0PerBlock = 128,
|
||||
ck_tile::index_t kK0PerBlock = 64,
|
||||
ck_tile::index_t kN1PerBlock = 128,
|
||||
ck_tile::index_t kK1PerBlock = 64>
|
||||
using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
@@ -90,78 +82,6 @@ using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
kK1PerBlock>;
|
||||
"""
|
||||
|
||||
# API_COMMON_HEADER = """
|
||||
# // SPDX-License-Identifier: MIT
|
||||
# // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
# #include <ck_tile/core.hpp>
|
||||
# #include "flash_attention_fwd.hpp"
|
||||
# #include <iostream>
|
||||
|
||||
# #pragma once
|
||||
|
||||
# using S = ck_tile::stream_config;
|
||||
# using A = FlashAttnArgs;
|
||||
|
||||
# {F_traits_define}
|
||||
|
||||
# template <typename QDataType,
|
||||
# typename KDataType,
|
||||
# typename VDataType,
|
||||
# typename ODataType,
|
||||
# typename Traits_>
|
||||
# float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
# const ck_tile::stream_config& stream_config) {{
|
||||
# using SaccDataType = typename Traits_::SaccDataType;
|
||||
# using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
|
||||
# using PDataType = typename Traits_::PDataType;
|
||||
# using OaccDataType = typename Traits_::OaccDataType;
|
||||
|
||||
# index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
|
||||
|
||||
# if(stream_config.log_level_ > 0)
|
||||
# std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush;
|
||||
|
||||
# return ck_tile::launch_kernel(stream_config,
|
||||
# ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
|
||||
# ck_tile::FlashAttentionFwd<QDataType,
|
||||
# KDataType,
|
||||
# VDataType,
|
||||
# SaccDataType,
|
||||
# SMPLComputeDataType,
|
||||
# PDataType,
|
||||
# OaccDataType,
|
||||
# ODataType,
|
||||
# Traits_::kBlockSize,
|
||||
# Traits_::kHeadDim,
|
||||
# Traits_::kM0PerBlock,
|
||||
# Traits_::kN0PerBlock,
|
||||
# Traits_::kK0PerBlock,
|
||||
# Traits_::kN1PerBlock,
|
||||
# Traits_::kK1PerBlock>{{}},
|
||||
# kGridSize,
|
||||
# Traits_::kBlockSize,
|
||||
# 0,
|
||||
# a.q_ptr,
|
||||
# a.k_ptr,
|
||||
# a.v_ptr,
|
||||
# a.o_ptr,
|
||||
# a.M0,
|
||||
# a.N0,
|
||||
# a.K0,
|
||||
# a.N1,
|
||||
# a.Batch,
|
||||
# a.strideQ, // StrideQ
|
||||
# a.strideK, // StrideK
|
||||
# a.strideV, // StrideV
|
||||
# a.strideO, // StrideO
|
||||
# a.batchStrideQ, // BatchStrideQ
|
||||
# a.batchStrideK, // BatchStrideK
|
||||
# a.batchStrideV, // BatchStrideV
|
||||
# a.batchStrideO)); // BatchStrideO
|
||||
# }}
|
||||
# """
|
||||
|
||||
API_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
@@ -204,14 +124,6 @@ template float flash_attention_fwd<ck_tile::half_t, ck_tile::half_t, ck_tile::ha
|
||||
}}
|
||||
"""
|
||||
|
||||
# API_PER_DTYPE = """ {F_if}(std::is_same_v<QDataType, {F_q_type}> && std::is_same_v<KDataType, {F_k_type}> && std::is_same_v<VDataType, {F_v_type}> && std::is_same_v<ODataType, {F_o_type}>) {{
|
||||
# {F_per_size_case}
|
||||
# }}
|
||||
# """
|
||||
# API_PER_SIZE_CASE = """ {F_if} {F_SIZE_COND} {{
|
||||
# {F_inner_dispatch}
|
||||
# }}
|
||||
# """
|
||||
API_INNER_CASE = """ {F_if} {F_VEC_COND}
|
||||
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(a, stream_config);
|
||||
"""
|
||||
@@ -320,13 +232,13 @@ template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
index_t kBlockSize_,
|
||||
index_t kHeadDim_,
|
||||
index_t kM0PerBlock_,
|
||||
index_t kN0PerBlock_,
|
||||
index_t kK0PerBlock_,
|
||||
index_t kN1PerBlock_,
|
||||
index_t kK1PerBlock_>
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kHeadDim_ = 128,
|
||||
index_t kM0PerBlock_ = 128,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
@@ -456,36 +368,28 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
len_in_size = sum(len(b.instance_list) for b in blob_per_size)
|
||||
|
||||
size_cond = ""
|
||||
if size_ == "small":
|
||||
size_cond = "(a.M0 < 2048 && a.N0 < 2048)"
|
||||
elif size_ == "medium":
|
||||
size_cond = "(a.M0 >= 2048 && a.N0 >= 2048 && a.M0 < 4096 && a.N0 < 4096)"
|
||||
else: # large
|
||||
size_cond = "(a.M0 >= 4096 || a.N0 >= 4096)"
|
||||
if size_ == "head_dim_256_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 256 && a.N1 == 256)"
|
||||
elif size_ == "head_dim_128_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_64_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 64 && a.N1 == 64)"
|
||||
elif size_ == "head_dim_32_seq_4096":
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 32 && a.N1 == 32)"
|
||||
elif size_ == "head_dim_128_seq_2048":
|
||||
size_cond = "(a.M0 <= 2048 && a.N0 <= 2048 && a.M0 > 512 && a.N0 > 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
elif size_ == "head_dim_128_seq_512":
|
||||
size_cond = "(a.M0 <= 512 && a.N0 <= 512 && a.K0 == 128 && a.N1 == 128)"
|
||||
else:
|
||||
size_cond = "(a.M0 <= 4096 && a.N0 <= 4096 && a.M0 > 2048 && a.N0 > 2048 && a.K0 == 128 && a.N1 == 128)"
|
||||
|
||||
inner_str += self.API_INNER_CASE.format(
|
||||
# F_if=get_if_str(idx_in_size, len_in_size, False),
|
||||
F_if=get_if_str(size_, len_in_size, False),
|
||||
F_VEC_COND=size_cond,
|
||||
F_trait_name=ins.trait_name
|
||||
)
|
||||
|
||||
# size_str += self.API_PER_SIZE_CASE.format(
|
||||
# F_if=get_if_str(i_size, len(blob_per_t)),
|
||||
# F_SIZE_COND=size_cond,
|
||||
# F_inner_dispatch=inner_str
|
||||
# )
|
||||
size_str += inner_str
|
||||
|
||||
# q_type, k_type, v_type, o_type = dtype_.split(',')
|
||||
# d_str += self.API_PER_DTYPE.format(
|
||||
# F_if=get_if_str(i_d, len(t_dtype_dict)),
|
||||
# F_q_type=DATA_TYPE_MAP[q_type],
|
||||
# F_k_type=DATA_TYPE_MAP[k_type],
|
||||
# F_v_type=DATA_TYPE_MAP[v_type],
|
||||
# F_o_type=DATA_TYPE_MAP[o_type],
|
||||
# F_per_size_case=size_str
|
||||
# )
|
||||
|
||||
d_str += size_str
|
||||
|
||||
api_base = self.API_BASE.format(
|
||||
@@ -500,18 +404,24 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
|
||||
# Define kernel configurations for different size categories
|
||||
trait_dict = {
|
||||
"small": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 128, 128, 128, 32, 128, 32),
|
||||
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 64, 32, 64, 32),
|
||||
"head_dim_256_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
|
||||
],
|
||||
"medium": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 128, 128, 128, 32, 128, 32),
|
||||
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 256, 128, 32, 128, 32),
|
||||
"head_dim_128_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
|
||||
],
|
||||
"head_dim_64_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 64, 64, 64, 64, 64, 64),
|
||||
],
|
||||
"head_dim_32_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 128, 32, 32, 32, 32, 32, 32),
|
||||
],
|
||||
"head_dim_128_seq_2048": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 64, 128, 64),
|
||||
],
|
||||
"head_dim_128_seq_512": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 128, 128, 128),
|
||||
],
|
||||
"large": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 128, 128, 128, 32, 128, 32),
|
||||
# h_traits('fp32', 'fp32', 'fp32', 'fp32', 512, 128, 256, 256, 32, 256, 32),
|
||||
]
|
||||
}
|
||||
|
||||
# Toy example only support fp16
|
||||
@@ -561,7 +471,7 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
blobs = self.get_blobs(args)
|
||||
for b in blobs:
|
||||
(w_p / (b.name + ".cpp")).write_text(b.content)
|
||||
|
||||
|
||||
def list_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
|
||||
Reference in New Issue
Block a user