Fix clang-format

This commit is contained in:
MHYang
2025-04-24 09:09:18 +00:00
parent c4b2d5074a
commit 0e6a23258e
17 changed files with 408 additions and 397 deletions

View File

@@ -32,7 +32,8 @@ struct BlockGemmARegBSmemCRegV1
// B block tile distribution for load from lds
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
constexpr auto config =
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
@@ -55,7 +56,8 @@ struct BlockGemmARegBSmemCRegV1
return b_block_dstr_encode;
}
static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
@@ -149,15 +151,14 @@ struct BlockGemmARegBSmemCRegV1
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp>
__device__ void operator() (CBlockTensor& c_block_tensor,
__device__ void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BLdsTile& b_block_tensor_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];

View File

@@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr (kM0 == 64)
if constexpr(kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
else if constexpr (kM0 == 32)
else if constexpr(kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
}
else if constexpr (kM0 == 128)
else if constexpr(kM0 == 128)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
}

View File

@@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1K8Policy
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr (kM0 == 64)
if constexpr(kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
}
else if constexpr (kM0 == 32)
else if constexpr(kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
}
else if constexpr (kM0 == 128)
else if constexpr(kM0 == 128)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
}

View File

@@ -261,7 +261,9 @@ struct BlockGemmPipelineAGmemBGmemCReg
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0},
b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
// Acc register tile
@@ -269,7 +271,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
b_lds_gemm_window)){};
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
#if !defined(TOY_FA_FWD_OPT)
@@ -279,10 +280,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
});
#else
@@ -322,10 +323,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
bWarpTile);
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
bWarpTile);
block_sync_lds();
@@ -344,7 +345,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
bWarpTile);
bWarpTile);
block_sync_lds();
}
@@ -358,7 +359,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
bWarpTile);
bWarpTile);
}
#endif
return c_block_tile;

View File

@@ -21,7 +21,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
{
@@ -60,14 +59,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
return a_block_dstr;
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeADramTileDistribution()
{
return MakeARegBlockDescriptor<Problem>();
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
{
@@ -99,24 +96,24 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_unmerge_transform(make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(make_unmerge_transform(
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
{

View File

@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
using OaccDataType = float;
using ODataType = ck_tile::half_t;
ck_tile::index_t Batch = 64; // Batch Number * Head Number
ck_tile::index_t M0 = 4096; // SequenceLengthQ
ck_tile::index_t N0 = 4096; // SequencelengthK
ck_tile::index_t K0 = 128; // HeadDim
ck_tile::index_t N1 = 128; // HeadDim
ck_tile::index_t verification = 0;
ck_tile::index_t init_method = 1;
ck_tile::index_t Batch = 64; // Batch Number * Head Number
ck_tile::index_t M0 = 4096; // SequenceLengthQ
ck_tile::index_t N0 = 4096; // SequencelengthK
ck_tile::index_t K0 = 128; // HeadDim
ck_tile::index_t N1 = 128; // HeadDim
ck_tile::index_t verification = 0;
ck_tile::index_t init_method = 1;
if(argc == 3)
{

View File

@@ -12,7 +12,6 @@
#include "block_gemm_areg_bsmem_creg_v1.hpp"
#include "flash_attention_fwd_impl.hpp"
namespace ck_tile {
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)

View File

@@ -14,7 +14,6 @@
#include "block_gemm_areg_bsmem_creg_v1.hpp"
#include "tile_gemm_shape.hpp"
namespace ck_tile {
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
@@ -196,14 +195,16 @@ struct FlashAttentionFwdImpl
// V LDS tile window for store
auto v_copy_lds_window =
make_tile_window(v_lds,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
v_dram_window.get_tile_distribution());
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
v_dram_window.get_tile_distribution());
// V LDS tile for block GEMM
auto v_lds_gemm_window = make_tile_window(
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0},
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
auto v_lds_gemm_window =
make_tile_window(v_lds,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
#else
auto v_lds_window = make_tile_window(
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
@@ -321,10 +322,10 @@ struct FlashAttentionFwdImpl
store_tile(v_lds_window, v);
block_sync_lds();
gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
block_sync_lds();
});
#else
@@ -361,10 +362,10 @@ struct FlashAttentionFwdImpl
move_tile_window(v_dram_window, {0, kK1PerBlock});
gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
vWarpTile);
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
vWarpTile);
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
gemm1.template HotLoopScheduler<8, 4>();
@@ -373,23 +374,23 @@ struct FlashAttentionFwdImpl
}
// tail
{
if constexpr (k1_loops > 1)
if constexpr(k1_loops > 1)
{
gemm1(o_acc,
get_slice_tile(p,
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
vWarpTile);
get_slice_tile(p,
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
vWarpTile);
block_sync_lds();
}
store_tile(v_copy_lds_window, v_prefetch);
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
gemm1(o_acc,
get_slice_tile(p,
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
sequence<kM0PerBlock, kN0PerBlock>{}),
vWarpTile);
get_slice_tile(p,
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
sequence<kM0PerBlock, kN0PerBlock>{}),
vWarpTile);
block_sync_lds();
}
#endif

View File

@@ -32,7 +32,8 @@ struct BlockGemmARegBSmemCRegV1
// B block tile distribution for load from lds
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
{
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
constexpr auto config =
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
using WG = remove_cvref_t<decltype(config.template get<0>())>;
constexpr index_t MWarp = config.template get<1>();
@@ -55,7 +56,8 @@ struct BlockGemmARegBSmemCRegV1
return b_block_dstr_encode;
}
static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
static constexpr auto BLdsTileDistr =
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
@@ -149,15 +151,14 @@ struct BlockGemmARegBSmemCRegV1
// C += A * B
template <typename CBlockTensor, typename ABlockTensorTmp>
__device__ void operator() (CBlockTensor& c_block_tensor,
__device__ void operator()(CBlockTensor& c_block_tensor,
const ABlockTensorTmp& a_block_tensor_tmp,
const BLdsTile& b_block_tensor_tmp) const
{
static_assert(
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];

View File

@@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr (kM0 == 64)
if constexpr(kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
}
else if constexpr (kM0 == 32)
else if constexpr(kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
}
else if constexpr (kM0 == 128)
else if constexpr(kM0 == 128)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
}

View File

@@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1K8Policy
template <typename Problem, index_t kM0>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr (kM0 == 64)
if constexpr(kM0 == 64)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
}
else if constexpr (kM0 == 32)
else if constexpr(kM0 == 32)
{
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
}
else if constexpr (kM0 == 128)
else if constexpr(kM0 == 128)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
}

View File

@@ -261,7 +261,9 @@ struct BlockGemmPipelineAGmemBGmemCReg
// B LDS tile for block GEMM
auto b_lds_gemm_window = make_tile_window(
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0},
b_lds_block,
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
// Acc register tile
@@ -269,7 +271,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
b_lds_gemm_window)){};
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
#if !defined(TOY_FA_FWD_OPT)
@@ -279,10 +280,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
store_tile(b_copy_lds_window, b_block_tile);
block_sync_lds();
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
b_copy_lds_window);
block_sync_lds();
});
#else
@@ -322,10 +323,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
block_gemm(c_block_tile,
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
bWarpTile);
get_slice_tile(a_copy_reg_tensor,
sequence<0, i_k0 * kKPerBlock>{},
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
bWarpTile);
block_sync_lds();
@@ -344,7 +345,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 2) * kKPerBlock>{},
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
bWarpTile);
bWarpTile);
block_sync_lds();
}
@@ -358,7 +359,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
get_slice_tile(a_copy_reg_tensor,
sequence<0, (k_loops - 1) * kKPerBlock>{},
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
bWarpTile);
bWarpTile);
}
#endif
return c_block_tile;

View File

@@ -21,7 +21,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
{
@@ -60,14 +59,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
return a_block_dstr;
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeADramTileDistribution()
{
return MakeARegBlockDescriptor<Problem>();
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
{
@@ -99,24 +96,24 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_unmerge_transform(make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(make_unmerge_transform(
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<kKPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(
make_merge_transform(make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
template <typename Problem>
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
{

View File

@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
using OaccDataType = float;
using ODataType = ck_tile::half_t;
ck_tile::index_t Batch = 64; // Batch Number * Head Number
ck_tile::index_t M0 = 4096; // SequenceLengthQ
ck_tile::index_t N0 = 4096; // SequencelengthK
ck_tile::index_t K0 = 128; // HeadDim
ck_tile::index_t N1 = 128; // HeadDim
ck_tile::index_t verification = 0;
ck_tile::index_t init_method = 1;
ck_tile::index_t Batch = 64; // Batch Number * Head Number
ck_tile::index_t M0 = 4096; // SequenceLengthQ
ck_tile::index_t N0 = 4096; // SequencelengthK
ck_tile::index_t K0 = 128; // HeadDim
ck_tile::index_t N1 = 128; // HeadDim
ck_tile::index_t verification = 0;
ck_tile::index_t init_method = 1;
if(argc == 3)
{

View File

@@ -155,7 +155,6 @@ struct FlashAttentionFwd
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) %
num_tile_n1 * kN1PerBlock);
const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
KDataType,
VDataType,

View File

@@ -14,7 +14,6 @@
#include "block_gemm_areg_bsmem_creg_v1.hpp"
#include "tile_gemm_shape.hpp"
namespace ck_tile {
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
@@ -196,14 +195,16 @@ struct FlashAttentionFwdImpl
// V LDS tile window for store
auto v_copy_lds_window =
make_tile_window(v_lds,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
v_dram_window.get_tile_distribution());
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
v_dram_window.get_tile_distribution());
// V LDS tile for block GEMM
auto v_lds_gemm_window = make_tile_window(
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0},
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
auto v_lds_gemm_window =
make_tile_window(v_lds,
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
{0, 0},
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
#else
auto v_lds_window = make_tile_window(
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
@@ -321,10 +322,10 @@ struct FlashAttentionFwdImpl
store_tile(v_lds_window, v);
block_sync_lds();
gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
v_lds_window);
block_sync_lds();
});
#else
@@ -361,10 +362,10 @@ struct FlashAttentionFwdImpl
move_tile_window(v_dram_window, {0, kK1PerBlock});
gemm1(o_acc,
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
vWarpTile);
get_slice_tile(p,
sequence<0, i_k1 * kK1PerBlock>{},
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
vWarpTile);
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
gemm1.template HotLoopScheduler<8, 4>();
@@ -373,23 +374,23 @@ struct FlashAttentionFwdImpl
}
// tail
{
if constexpr (k1_loops > 1)
if constexpr(k1_loops > 1)
{
gemm1(o_acc,
get_slice_tile(p,
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
vWarpTile);
get_slice_tile(p,
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
vWarpTile);
block_sync_lds();
}
store_tile(v_copy_lds_window, v_prefetch);
block_sync_lds();
vWarpTile = load_tile(v_lds_gemm_window);
gemm1(o_acc,
get_slice_tile(p,
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
sequence<kM0PerBlock, kN0PerBlock>{}),
vWarpTile);
get_slice_tile(p,
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
sequence<kM0PerBlock, kN0PerBlock>{}),
vWarpTile);
block_sync_lds();
}
#endif

View File

@@ -1,33 +1,18 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#SPDX - License - Identifier : MIT
#Copyright(c) 2025, Advanced Micro Devices, Inc.All rights reserved.
import argparse
from enum import IntEnum
from pathlib import Path
import sys
from typing import List, Optional, Any
import functools
import itertools
import copy
from dataclasses import dataclass
import argparse from enum import IntEnum from pathlib import Path import sys from typing import List, Optional, Any import functools import itertools import copy from dataclasses import dataclass
def get_if_str(size_, total, last_else=True):
if size_ == "head_dim_256_seq_4096":
return 'if'
else:
return 'else if'
def get_if_str(size_, total, last_else = True) : if size_ == "head_dim_256_seq_4096" : return 'if' else : return 'else if'
DATA_TYPE_MAP = {'fp32': 'float',
'fp16': 'ck_tile::half_t',
'bf16': 'ck_tile::bf16_t'}
DATA_TYPE_MAP = {'fp32' : 'float', 'fp16' : 'ck_tile::half_t', 'bf16' : 'ck_tile::bf16_t' }
def BOOL_MAP(b_) -> str:
return 'true' if b_ else 'false'
def BOOL_MAP(b_)->str: return 'true' if b_ else 'false'
class FlashAttentionFwdCodegen:
API_TRAITS_DEFINE = """
class FlashAttentionFwdCodegen:API_TRAITS_DEFINE = ""
"
template <typename SaccDataType_,
template <typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
@@ -37,166 +22,184 @@ template <typename SaccDataType_,
index_t kN0PerBlock_ = 128,
index_t kK0PerBlock_ = 64,
index_t kN1PerBlock_ = 128,
index_t kK1PerBlock_ = 64>
struct flash_attention_fwd_traits_
{
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_{using SaccDataType = ck_tile::remove_cvref_t <SaccDataType_>;
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kHeadDim = kHeadDim_;
static constexpr index_t kM0PerBlock = kM0PerBlock_;
static constexpr index_t kN0PerBlock = kN0PerBlock_;
static constexpr index_t kK0PerBlock = kK0PerBlock_;
static constexpr index_t kN1PerBlock = kN1PerBlock_;
static constexpr index_t kK1PerBlock = kK1PerBlock_;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kHeadDim = kHeadDim_;
static constexpr index_t kM0PerBlock = kM0PerBlock_;
static constexpr index_t kN0PerBlock = kN0PerBlock_;
static constexpr index_t kK0PerBlock = kK0PerBlock_;
static constexpr index_t kN1PerBlock = kN1PerBlock_;
static constexpr index_t kK1PerBlock = kK1PerBlock_;
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
};
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
}
;
template <typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
ck_tile::index_t kBlockSize = 256,
ck_tile::index_t kHeadDim = 128,
ck_tile::index_t kBlockSize = 256,
ck_tile::index_t kHeadDim = 128,
ck_tile::index_t kM0PerBlock = 128,
ck_tile::index_t kN0PerBlock = 128,
ck_tile::index_t kK0PerBlock = 64,
ck_tile::index_t kN1PerBlock = 128,
ck_tile::index_t kK1PerBlock = 64>
using traits_ = flash_attention_fwd_traits_<SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>;
"""
SMPLComputeDataType,
PDataType,
OaccDataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>;
""
"
API_BASE = """
API_BASE = ""
"
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <ck_tile/core.hpp>
#include "flash_attention_fwd.hpp"
namespace ck_tile {{
namespace ck_tile
{
{
{F_traits_define}
{
F_traits_define
}
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template <typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename Traits_>
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config);
// Note: this internal API only declare, not define here, otherwise will block `make -j`
template <typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename Traits_>
float flash_attention_fwd_(
const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config);
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType>
float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config) {{
float r = -1;
{F_dispatch}
return r;
}}
template float flash_attention_fwd<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, float, float, ck_tile::half_t, float, ck_tile::half_t>(
const FlashAttnArgs<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t>&,
const ck_tile::stream_config&);
}}
"""
API_INNER_CASE = """ {F_if} {F_VEC_COND}
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(a, stream_config);
"""
INSTANCE_BASE = """
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "flash_attention_fwd_api_common.hpp"
namespace ck_tile {
// clang-format off
//
{F_instance_def}
// clang-format on
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType>
float flash_attention_fwd(
const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config)
{
{
float r = -1;
{
F_dispatch
}
return r;
}
}
template float flash_attention_fwd<ck_tile::half_t,
ck_tile::half_t,
ck_tile::half_t,
float,
float,
ck_tile::half_t,
float,
ck_tile::half_t>(const FlashAttnArgs<ck_tile::half_t,
ck_tile::half_t,
ck_tile::half_t,
ck_tile::half_t>&,
const ck_tile::stream_config&);
}
}
"""
""
"
def __init__(self, working_path, kernel_filter):
self.working_path = working_path
self.kernel_filter = kernel_filter
API_INNER_CASE = ""
" {F_if} {F_VEC_COND}
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(
a, stream_config);
""
"
@dataclass
class h_traits:
F_SaccDataType: str
F_SMPLComputeDataType: str
F_PDataType: str
F_OaccDataType: str
F_kBlockSize: int
F_kHeadDim: int
F_kM0PerBlock: int
F_kN0PerBlock: int
F_kK0PerBlock: int
F_kN1PerBlock: int
F_kK1PerBlock: int
@property
def trait_name(self) -> str:
return (f"{DATA_TYPE_MAP[self.F_SaccDataType]}, "
f"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, "
f"{DATA_TYPE_MAP[self.F_PDataType]}, "
f"{DATA_TYPE_MAP[self.F_OaccDataType]}, "
f"{self.F_kBlockSize}, {self.F_kHeadDim}, "
f"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, "
f"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}")
@property
def def_name(self) -> str:
return (f"template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, "
f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, "
f"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, "
f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, "
"const ck_tile::stream_config&);")
@dataclass
class h_instance:
F_DataTypePair: str # "q,k,v,o"
F_SizeCategory: str # "small", "medium", "large"
instance_list: List[Any] # List[h_traits]
INSTANCE_BASE = """
INSTANCE_BASE = ""
"
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "flash_attention_fwd_api_common.hpp"
namespace ck_tile {{
// clang-format off
namespace ck_tile
{
// clang-format off
//
{F_instance_def}
// clang-format on
}}
// clang-format on
}
""
"
def
__init__(self, working_path, kernel_filter)
: self.working_path = working_path self.kernel_filter = kernel_filter
@dataclass class h_traits
: F_SaccDataType : str F_SMPLComputeDataType : str F_PDataType : str F_OaccDataType
: str F_kBlockSize : int F_kHeadDim : int F_kM0PerBlock : int F_kN0PerBlock : int F_kK0PerBlock
: int F_kN1PerBlock : int F_kK1PerBlock : int
@property def trait_name(self)
->str
: return (f "{DATA_TYPE_MAP[self.F_SaccDataType]}, " f
"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " f
"{DATA_TYPE_MAP[self.F_PDataType]}, " f "{DATA_TYPE_MAP[self.F_OaccDataType]}, " f
"{self.F_kBlockSize}, {self.F_kHeadDim}, " f
"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " f
"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}")
@property def def_name(self)
->str
: return (f "template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " f
"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " f
"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " f
"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, "
"const ck_tile::stream_config&);")
@dataclass class h_instance : F_DataTypePair : str #"q,k,v,o" F_SizeCategory : str
#"small",
"medium",
"large" instance_list : List[Any] #List[h_traits]
INSTANCE_BASE = ""
"
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "flash_attention_fwd_api_common.hpp"
namespace ck_tile
{
{
// clang-format off
//
{F_instance_def}
// clang-format on
}}
"""
@property
@@ -226,123 +229,133 @@ namespace ck_tile {{
#include "flash_attention_fwd.hpp"
namespace ck_tile {{
namespace ck_tile
{
{
template <typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
index_t kBlockSize_ = 256,
index_t kHeadDim_ = 128,
index_t kM0PerBlock_ = 128,
index_t kN0PerBlock_ = 128,
index_t kK0PerBlock_ = 64,
index_t kN1PerBlock_ = 128,
index_t kK1PerBlock_ = 64>
struct flash_attention_fwd_traits_
{{
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
template <typename SaccDataType_,
typename SMPLComputeDataType_,
typename PDataType_,
typename OaccDataType_,
index_t kBlockSize_ = 256,
index_t kHeadDim_ = 128,
index_t kM0PerBlock_ = 128,
index_t kN0PerBlock_ = 128,
index_t kK0PerBlock_ = 64,
index_t kN1PerBlock_ = 128,
index_t kK1PerBlock_ = 64>
struct flash_attention_fwd_traits_
{
{
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kHeadDim = kHeadDim_;
static constexpr index_t kM0PerBlock = kM0PerBlock_;
static constexpr index_t kN0PerBlock = kN0PerBlock_;
static constexpr index_t kK0PerBlock = kK0PerBlock_;
static constexpr index_t kN1PerBlock = kN1PerBlock_;
static constexpr index_t kK1PerBlock = kK1PerBlock_;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t kHeadDim = kHeadDim_;
static constexpr index_t kM0PerBlock = kM0PerBlock_;
static constexpr index_t kN0PerBlock = kN0PerBlock_;
static constexpr index_t kK0PerBlock = kK0PerBlock_;
static constexpr index_t kN1PerBlock = kN1PerBlock_;
static constexpr index_t kK1PerBlock = kK1PerBlock_;
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
}};
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
}
};
template <typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
ck_tile::index_t kBlockSize,
ck_tile::index_t kHeadDim,
ck_tile::index_t kM0PerBlock,
ck_tile::index_t kN0PerBlock,
ck_tile::index_t kK0PerBlock,
ck_tile::index_t kN1PerBlock,
ck_tile::index_t kK1PerBlock>
using traits_ = flash_attention_fwd_traits_<SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>;
template <typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
ck_tile::index_t kBlockSize,
ck_tile::index_t kHeadDim,
ck_tile::index_t kM0PerBlock,
ck_tile::index_t kN0PerBlock,
ck_tile::index_t kK0PerBlock,
ck_tile::index_t kN1PerBlock,
ck_tile::index_t kK1PerBlock>
using traits_ = flash_attention_fwd_traits_<SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>;
template <typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename Traits_>
float flash_attention_fwd_(
const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config)
{
{
using SaccDataType = typename Traits_::SaccDataType;
using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
using PDataType = typename Traits_::PDataType;
using OaccDataType = typename Traits_::OaccDataType;
index_t kGridSize =
a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
template <typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType,
typename Traits_>
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config) {{
using SaccDataType = typename Traits_::SaccDataType;
using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
using PDataType = typename Traits_::PDataType;
using OaccDataType = typename Traits_::OaccDataType;
if(stream_config.log_level_ > 0)
std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << ","
<< Traits_::kHeadDim << ">" << std::flush;
index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
if(stream_config.log_level_ > 0)
std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush;
return ck_tile::launch_kernel(stream_config,
ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
ck_tile::FlashAttentionFwd<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
Traits_::kBlockSize,
Traits_::kHeadDim,
Traits_::kM0PerBlock,
Traits_::kN0PerBlock,
Traits_::kK0PerBlock,
Traits_::kN1PerBlock,
Traits_::kK1PerBlock>{{}},
kGridSize,
Traits_::kBlockSize,
0,
a.q_ptr,
a.k_ptr,
a.v_ptr,
a.o_ptr,
a.M0,
a.N0,
a.K0,
a.N1,
a.Batch,
a.strideQ, // StrideQ
a.strideK, // StrideK
a.strideV, // StrideV
a.strideO, // StrideO
a.batchStrideQ, // BatchStrideQ
a.batchStrideK, // BatchStrideK
a.batchStrideV, // BatchStrideV
a.batchStrideO)); // BatchStrideO
}}
}}
return ck_tile::launch_kernel(
stream_config,
ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
ck_tile::FlashAttentionFwd<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
Traits_::kBlockSize,
Traits_::kHeadDim,
Traits_::kM0PerBlock,
Traits_::kN0PerBlock,
Traits_::kK0PerBlock,
Traits_::kN1PerBlock,
Traits_::kK1PerBlock>{{}},
kGridSize,
Traits_::kBlockSize,
0,
a.q_ptr,
a.k_ptr,
a.v_ptr,
a.o_ptr,
a.M0,
a.N0,
a.K0,
a.N1,
a.Batch,
a.strideQ, // StrideQ
a.strideK, // StrideK
a.strideV, // StrideV
a.strideO, // StrideO
a.batchStrideQ, // BatchStrideQ
a.batchStrideK, // BatchStrideK
a.batchStrideV, // BatchStrideV
a.batchStrideO)); // BatchStrideO
}
}
}
}
"""
def content_api(self, args) -> str:
# Sort based on dtype
#Sort based on dtype
t_dtype_dict = {}
blobs = self.get_blobs(args)
@@ -402,7 +415,7 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
h_traits = self.h_traits
h_instance = self.h_instance
# Define kernel configurations for different size categories
#Define kernel configurations for different size categories
trait_dict = {
"head_dim_256_seq_4096": [
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
@@ -424,17 +437,17 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
],
}
# Toy example only support fp16
#Toy example only support fp16
dtype_combinations = [
"fp16,fp16,fp16,fp16"
# "bf16,bf16,bf16,bf16"
#"bf16,bf16,bf16,bf16"
]
total_blob = []
for dtype_pair in dtype_combinations:
for size_category in trait_dict:
traits = trait_dict[size_category]
# Convert data types for the current dtype_pair
#Convert data types for the current dtype_pair
q_type, k_type, v_type, o_type = dtype_pair.split(',')
current_traits = []
for t in traits:
@@ -455,10 +468,10 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
blobs = self.get_blobs(args)
with list_p.open('w') as list_f:
# API related files
#API related files
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
# Kernel instance files
#Kernel instance files
for b in blobs:
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")