mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
Fix clang-format
This commit is contained in:
@@ -32,7 +32,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
constexpr auto config =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
@@ -55,7 +56,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
@@ -149,15 +151,14 @@ struct BlockGemmARegBSmemCRegV1
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp>
|
||||
__device__ void operator() (CBlockTensor& c_block_tensor,
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BLdsTile& b_block_tensor_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
|
||||
|
||||
@@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr (kM0 == 64)
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 32)
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 128)
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
@@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1K8Policy
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr (kM0 == 64)
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 32)
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 128)
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
@@ -261,7 +261,9 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0},
|
||||
b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
|
||||
|
||||
// Acc register tile
|
||||
@@ -269,7 +271,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
@@ -279,10 +280,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
@@ -322,10 +323,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -344,7 +345,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -358,7 +359,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
bWarpTile);
|
||||
}
|
||||
#endif
|
||||
return c_block_tile;
|
||||
|
||||
@@ -21,7 +21,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
{
|
||||
@@ -60,14 +59,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
@@ -99,24 +96,24 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
|
||||
@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
|
||||
using OaccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
|
||||
if(argc == 3)
|
||||
{
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "flash_attention_fwd_impl.hpp"
|
||||
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "tile_gemm_shape.hpp"
|
||||
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
@@ -196,14 +195,16 @@ struct FlashAttentionFwdImpl
|
||||
// V LDS tile window for store
|
||||
auto v_copy_lds_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
|
||||
// V LDS tile for block GEMM
|
||||
auto v_lds_gemm_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
auto v_lds_gemm_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
@@ -321,10 +322,10 @@ struct FlashAttentionFwdImpl
|
||||
store_tile(v_lds_window, v);
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
@@ -361,10 +362,10 @@ struct FlashAttentionFwdImpl
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
@@ -373,23 +374,23 @@ struct FlashAttentionFwdImpl
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr (k1_loops > 1)
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
vWarpTile);
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -32,7 +32,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
// B block tile distribution for load from lds
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
constexpr auto config =
|
||||
Policy::template GetWarpGemmMWarpNWarp<Problem, Problem::BlockGemmShape::kM>();
|
||||
using WG = remove_cvref_t<decltype(config.template get<0>())>;
|
||||
|
||||
constexpr index_t MWarp = config.template get<1>();
|
||||
@@ -55,7 +56,8 @@ struct BlockGemmARegBSmemCRegV1
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
|
||||
static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
static constexpr auto BLdsTileDistr =
|
||||
decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){};
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
template <index_t VectorSizeB = 8, index_t SmemPack = 8>
|
||||
@@ -149,15 +151,14 @@ struct BlockGemmARegBSmemCRegV1
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensorTmp>
|
||||
__device__ void operator() (CBlockTensor& c_block_tensor,
|
||||
__device__ void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensorTmp& a_block_tensor_tmp,
|
||||
const BLdsTile& b_block_tensor_tmp) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensorTmp::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BLdsTile::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
|
||||
constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}];
|
||||
|
||||
@@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr (kM0 == 64)
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 32)
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 128)
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
@@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1K8Policy
|
||||
template <typename Problem, index_t kM0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
if constexpr (kM0 == 64)
|
||||
if constexpr(kM0 == 64)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 32)
|
||||
else if constexpr(kM0 == 32)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1);
|
||||
}
|
||||
else if constexpr (kM0 == 128)
|
||||
else if constexpr(kM0 == 128)
|
||||
{
|
||||
return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1);
|
||||
}
|
||||
|
||||
@@ -261,7 +261,9 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
// B LDS tile for block GEMM
|
||||
auto b_lds_gemm_window = make_tile_window(
|
||||
b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0},
|
||||
b_lds_block,
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode()));
|
||||
|
||||
// Acc register tile
|
||||
@@ -269,7 +271,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence<kMPerBlock, kKPerBlock>{}),
|
||||
b_lds_gemm_window)){};
|
||||
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
#if !defined(TOY_FA_FWD_OPT)
|
||||
@@ -279,10 +280,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
store_tile(b_copy_lds_window, b_block_tile);
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
b_copy_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
@@ -322,10 +323,10 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
move_tile_window(b_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
block_gemm(c_block_tile,
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, i_k0 * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (i_k0 + 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
@@ -344,7 +345,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 2) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, (k_loops - 1) * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
bWarpTile);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
@@ -358,7 +359,7 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
get_slice_tile(a_copy_reg_tensor,
|
||||
sequence<0, (k_loops - 1) * kKPerBlock>{},
|
||||
sequence<kMPerBlock, k_loops * kKPerBlock>{}),
|
||||
bWarpTile);
|
||||
bWarpTile);
|
||||
}
|
||||
#endif
|
||||
return c_block_tile;
|
||||
|
||||
@@ -21,7 +21,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
return BlockGemmARegBSmemCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeARegBlockDescriptor()
|
||||
{
|
||||
@@ -60,14 +59,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
return a_block_dstr;
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
return MakeARegBlockDescriptor<Problem>();
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
@@ -99,24 +96,24 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy
|
||||
|
||||
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<NLdsLayer>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kNPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(
|
||||
make_tuple(number<kNPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform(make_tuple(number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
|
||||
template <typename Problem>
|
||||
__host__ __device__ static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
|
||||
@@ -29,13 +29,13 @@ int main(int argc, char* argv[])
|
||||
using OaccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
ck_tile::index_t Batch = 64; // Batch Number * Head Number
|
||||
ck_tile::index_t M0 = 4096; // SequenceLengthQ
|
||||
ck_tile::index_t N0 = 4096; // SequencelengthK
|
||||
ck_tile::index_t K0 = 128; // HeadDim
|
||||
ck_tile::index_t N1 = 128; // HeadDim
|
||||
ck_tile::index_t verification = 0;
|
||||
ck_tile::index_t init_method = 1;
|
||||
|
||||
if(argc == 3)
|
||||
{
|
||||
|
||||
@@ -155,7 +155,6 @@ struct FlashAttentionFwd
|
||||
const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) %
|
||||
num_tile_n1 * kN1PerBlock);
|
||||
|
||||
|
||||
const auto kernel_impl = FlashAttentionFwdImpl<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
#include "block_gemm_areg_bsmem_creg_v1.hpp"
|
||||
#include "tile_gemm_shape.hpp"
|
||||
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
@@ -196,14 +195,16 @@ struct FlashAttentionFwdImpl
|
||||
// V LDS tile window for store
|
||||
auto v_copy_lds_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
v_dram_window.get_tile_distribution());
|
||||
|
||||
// V LDS tile for block GEMM
|
||||
auto v_lds_gemm_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
auto v_lds_gemm_window =
|
||||
make_tile_window(v_lds,
|
||||
make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}),
|
||||
{0, 0},
|
||||
make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode()));
|
||||
#else
|
||||
auto v_lds_window = make_tile_window(
|
||||
v_lds, make_tuple(number<kN1PerBlock>{}, number<kK1PerBlock>{}), {0, 0});
|
||||
@@ -321,10 +322,10 @@ struct FlashAttentionFwdImpl
|
||||
store_tile(v_lds_window, v);
|
||||
block_sync_lds();
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
v_lds_window);
|
||||
block_sync_lds();
|
||||
});
|
||||
#else
|
||||
@@ -361,10 +362,10 @@ struct FlashAttentionFwdImpl
|
||||
move_tile_window(v_dram_window, {0, kK1PerBlock});
|
||||
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
get_slice_tile(p,
|
||||
sequence<0, i_k1 * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (i_k1 + 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1.template HotLoopScheduler<8, 4>();
|
||||
@@ -373,23 +374,23 @@ struct FlashAttentionFwdImpl
|
||||
}
|
||||
// tail
|
||||
{
|
||||
if constexpr (k1_loops > 1)
|
||||
if constexpr(k1_loops > 1)
|
||||
{
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 2) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, (k1_loops - 1) * kK1PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
store_tile(v_copy_lds_window, v_prefetch);
|
||||
block_sync_lds();
|
||||
vWarpTile = load_tile(v_lds_gemm_window);
|
||||
gemm1(o_acc,
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
vWarpTile);
|
||||
get_slice_tile(p,
|
||||
sequence<0, (k1_loops - 1) * kK1PerBlock>{},
|
||||
sequence<kM0PerBlock, kN0PerBlock>{}),
|
||||
vWarpTile);
|
||||
block_sync_lds();
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -1,33 +1,18 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#SPDX - License - Identifier : MIT
|
||||
#Copyright(c) 2025, Advanced Micro Devices, Inc.All rights reserved.
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional, Any
|
||||
import functools
|
||||
import itertools
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import argparse from enum import IntEnum from pathlib import Path import sys from typing import List, Optional, Any import functools import itertools import copy from dataclasses import dataclass
|
||||
|
||||
def get_if_str(size_, total, last_else=True):
|
||||
if size_ == "head_dim_256_seq_4096":
|
||||
return 'if'
|
||||
else:
|
||||
return 'else if'
|
||||
def get_if_str(size_, total, last_else = True) : if size_ == "head_dim_256_seq_4096" : return 'if' else : return 'else if'
|
||||
|
||||
DATA_TYPE_MAP = {'fp32': 'float',
|
||||
'fp16': 'ck_tile::half_t',
|
||||
'bf16': 'ck_tile::bf16_t'}
|
||||
DATA_TYPE_MAP = {'fp32' : 'float', 'fp16' : 'ck_tile::half_t', 'bf16' : 'ck_tile::bf16_t' }
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
return 'true' if b_ else 'false'
|
||||
def BOOL_MAP(b_)->str: return 'true' if b_ else 'false'
|
||||
|
||||
class FlashAttentionFwdCodegen:
|
||||
API_TRAITS_DEFINE = """
|
||||
class FlashAttentionFwdCodegen:API_TRAITS_DEFINE = ""
|
||||
"
|
||||
|
||||
template <typename SaccDataType_,
|
||||
template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
@@ -37,166 +22,184 @@ template <typename SaccDataType_,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
|
||||
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
|
||||
index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_{using SaccDataType = ck_tile::remove_cvref_t <SaccDataType_>;
|
||||
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
|
||||
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kHeadDim = kHeadDim_;
|
||||
static constexpr index_t kM0PerBlock = kM0PerBlock_;
|
||||
static constexpr index_t kN0PerBlock = kN0PerBlock_;
|
||||
static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kHeadDim = kHeadDim_;
|
||||
static constexpr index_t kM0PerBlock = kM0PerBlock_;
|
||||
static constexpr index_t kN0PerBlock = kN0PerBlock_;
|
||||
static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
};
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size();
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
}
|
||||
;
|
||||
|
||||
template <typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
ck_tile::index_t kBlockSize = 256,
|
||||
ck_tile::index_t kHeadDim = 128,
|
||||
ck_tile::index_t kBlockSize = 256,
|
||||
ck_tile::index_t kHeadDim = 128,
|
||||
ck_tile::index_t kM0PerBlock = 128,
|
||||
ck_tile::index_t kN0PerBlock = 128,
|
||||
ck_tile::index_t kK0PerBlock = 64,
|
||||
ck_tile::index_t kN1PerBlock = 128,
|
||||
ck_tile::index_t kK1PerBlock = 64>
|
||||
using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>;
|
||||
"""
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>;
|
||||
""
|
||||
"
|
||||
|
||||
API_BASE = """
|
||||
API_BASE = ""
|
||||
"
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "flash_attention_fwd.hpp"
|
||||
|
||||
namespace ck_tile {{
|
||||
namespace ck_tile
|
||||
{
|
||||
{
|
||||
|
||||
{F_traits_define}
|
||||
{
|
||||
F_traits_define
|
||||
}
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename Traits_>
|
||||
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename Traits_>
|
||||
float flash_attention_fwd_(
|
||||
const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config);
|
||||
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType>
|
||||
float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config) {{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
|
||||
template float flash_attention_fwd<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, float, float, ck_tile::half_t, float, ck_tile::half_t>(
|
||||
const FlashAttnArgs<ck_tile::half_t, ck_tile::half_t, ck_tile::half_t, ck_tile::half_t>&,
|
||||
const ck_tile::stream_config&);
|
||||
|
||||
}}
|
||||
"""
|
||||
|
||||
API_INNER_CASE = """ {F_if} {F_VEC_COND}
|
||||
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(a, stream_config);
|
||||
"""
|
||||
|
||||
INSTANCE_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "flash_attention_fwd_api_common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// clang-format off
|
||||
//
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType>
|
||||
float flash_attention_fwd(
|
||||
const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config)
|
||||
{
|
||||
{
|
||||
float r = -1;
|
||||
{
|
||||
F_dispatch
|
||||
}
|
||||
return r;
|
||||
}
|
||||
}
|
||||
|
||||
template float flash_attention_fwd<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
float,
|
||||
ck_tile::half_t,
|
||||
float,
|
||||
ck_tile::half_t>(const FlashAttnArgs<ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::half_t>&,
|
||||
const ck_tile::stream_config&);
|
||||
}
|
||||
}
|
||||
"""
|
||||
""
|
||||
"
|
||||
|
||||
def __init__(self, working_path, kernel_filter):
|
||||
self.working_path = working_path
|
||||
self.kernel_filter = kernel_filter
|
||||
API_INNER_CASE = ""
|
||||
" {F_if} {F_VEC_COND}
|
||||
r = flash_attention_fwd_<QDataType, KDataType, VDataType, ODataType, traits_<{F_trait_name}>>(
|
||||
a, stream_config);
|
||||
""
|
||||
"
|
||||
|
||||
@dataclass
|
||||
class h_traits:
|
||||
F_SaccDataType: str
|
||||
F_SMPLComputeDataType: str
|
||||
F_PDataType: str
|
||||
F_OaccDataType: str
|
||||
F_kBlockSize: int
|
||||
F_kHeadDim: int
|
||||
F_kM0PerBlock: int
|
||||
F_kN0PerBlock: int
|
||||
F_kK0PerBlock: int
|
||||
F_kN1PerBlock: int
|
||||
F_kK1PerBlock: int
|
||||
|
||||
@property
|
||||
def trait_name(self) -> str:
|
||||
return (f"{DATA_TYPE_MAP[self.F_SaccDataType]}, "
|
||||
f"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, "
|
||||
f"{DATA_TYPE_MAP[self.F_PDataType]}, "
|
||||
f"{DATA_TYPE_MAP[self.F_OaccDataType]}, "
|
||||
f"{self.F_kBlockSize}, {self.F_kHeadDim}, "
|
||||
f"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, "
|
||||
f"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}")
|
||||
|
||||
@property
|
||||
def def_name(self) -> str:
|
||||
return (f"template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, "
|
||||
f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, "
|
||||
f"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, "
|
||||
f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, "
|
||||
"const ck_tile::stream_config&);")
|
||||
|
||||
@dataclass
|
||||
class h_instance:
|
||||
F_DataTypePair: str # "q,k,v,o"
|
||||
F_SizeCategory: str # "small", "medium", "large"
|
||||
instance_list: List[Any] # List[h_traits]
|
||||
|
||||
INSTANCE_BASE = """
|
||||
INSTANCE_BASE = ""
|
||||
"
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "flash_attention_fwd_api_common.hpp"
|
||||
|
||||
namespace ck_tile {{
|
||||
// clang-format off
|
||||
namespace ck_tile
|
||||
{
|
||||
// clang-format off
|
||||
//
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
}}
|
||||
// clang-format on
|
||||
}
|
||||
""
|
||||
"
|
||||
|
||||
def
|
||||
__init__(self, working_path, kernel_filter)
|
||||
: self.working_path = working_path self.kernel_filter = kernel_filter
|
||||
|
||||
@dataclass class h_traits
|
||||
: F_SaccDataType : str F_SMPLComputeDataType : str F_PDataType : str F_OaccDataType
|
||||
: str F_kBlockSize : int F_kHeadDim : int F_kM0PerBlock : int F_kN0PerBlock : int F_kK0PerBlock
|
||||
: int F_kN1PerBlock : int F_kK1PerBlock : int
|
||||
|
||||
@property def trait_name(self)
|
||||
->str
|
||||
: return (f "{DATA_TYPE_MAP[self.F_SaccDataType]}, " f
|
||||
"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " f
|
||||
"{DATA_TYPE_MAP[self.F_PDataType]}, " f "{DATA_TYPE_MAP[self.F_OaccDataType]}, " f
|
||||
"{self.F_kBlockSize}, {self.F_kHeadDim}, " f
|
||||
"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " f
|
||||
"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}")
|
||||
|
||||
@property def def_name(self)
|
||||
->str
|
||||
: return (f "template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " f
|
||||
"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " f
|
||||
"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " f
|
||||
"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, "
|
||||
"const ck_tile::stream_config&);")
|
||||
|
||||
@dataclass class h_instance : F_DataTypePair : str #"q,k,v,o" F_SizeCategory : str
|
||||
#"small",
|
||||
"medium",
|
||||
"large" instance_list : List[Any] #List[h_traits]
|
||||
|
||||
INSTANCE_BASE = ""
|
||||
"
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "flash_attention_fwd_api_common.hpp"
|
||||
|
||||
namespace ck_tile
|
||||
{
|
||||
{
|
||||
// clang-format off
|
||||
//
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
}}
|
||||
"""
|
||||
|
||||
@property
|
||||
@@ -226,123 +229,133 @@ namespace ck_tile {{
|
||||
|
||||
#include "flash_attention_fwd.hpp"
|
||||
|
||||
namespace ck_tile {{
|
||||
namespace ck_tile
|
||||
{
|
||||
{
|
||||
|
||||
template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kHeadDim_ = 128,
|
||||
index_t kM0PerBlock_ = 128,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
|
||||
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
|
||||
template <typename SaccDataType_,
|
||||
typename SMPLComputeDataType_,
|
||||
typename PDataType_,
|
||||
typename OaccDataType_,
|
||||
index_t kBlockSize_ = 256,
|
||||
index_t kHeadDim_ = 128,
|
||||
index_t kM0PerBlock_ = 128,
|
||||
index_t kN0PerBlock_ = 128,
|
||||
index_t kK0PerBlock_ = 64,
|
||||
index_t kN1PerBlock_ = 128,
|
||||
index_t kK1PerBlock_ = 64>
|
||||
struct flash_attention_fwd_traits_
|
||||
{
|
||||
{
|
||||
using SaccDataType = ck_tile::remove_cvref_t<SaccDataType_>;
|
||||
using SMPLComputeDataType = ck_tile::remove_cvref_t<SMPLComputeDataType_>;
|
||||
using PDataType = ck_tile::remove_cvref_t<PDataType_>;
|
||||
using OaccDataType = ck_tile::remove_cvref_t<OaccDataType_>;
|
||||
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kHeadDim = kHeadDim_;
|
||||
static constexpr index_t kM0PerBlock = kM0PerBlock_;
|
||||
static constexpr index_t kN0PerBlock = kN0PerBlock_;
|
||||
static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kHeadDim = kHeadDim_;
|
||||
static constexpr index_t kM0PerBlock = kM0PerBlock_;
|
||||
static constexpr index_t kN0PerBlock = kN0PerBlock_;
|
||||
static constexpr index_t kK0PerBlock = kK0PerBlock_;
|
||||
static constexpr index_t kN1PerBlock = kN1PerBlock_;
|
||||
static constexpr index_t kK1PerBlock = kK1PerBlock_;
|
||||
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
}};
|
||||
static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
ck_tile::index_t kBlockSize,
|
||||
ck_tile::index_t kHeadDim,
|
||||
ck_tile::index_t kM0PerBlock,
|
||||
ck_tile::index_t kN0PerBlock,
|
||||
ck_tile::index_t kK0PerBlock,
|
||||
ck_tile::index_t kN1PerBlock,
|
||||
ck_tile::index_t kK1PerBlock>
|
||||
using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>;
|
||||
|
||||
template <typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
ck_tile::index_t kBlockSize,
|
||||
ck_tile::index_t kHeadDim,
|
||||
ck_tile::index_t kM0PerBlock,
|
||||
ck_tile::index_t kN0PerBlock,
|
||||
ck_tile::index_t kK0PerBlock,
|
||||
ck_tile::index_t kN1PerBlock,
|
||||
ck_tile::index_t kK1PerBlock>
|
||||
using traits_ = flash_attention_fwd_traits_<SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>;
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename Traits_>
|
||||
float flash_attention_fwd_(
|
||||
const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config)
|
||||
{
|
||||
{
|
||||
using SaccDataType = typename Traits_::SaccDataType;
|
||||
using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
|
||||
using PDataType = typename Traits_::PDataType;
|
||||
using OaccDataType = typename Traits_::OaccDataType;
|
||||
|
||||
index_t kGridSize =
|
||||
a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
|
||||
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType,
|
||||
typename Traits_>
|
||||
float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config) {{
|
||||
using SaccDataType = typename Traits_::SaccDataType;
|
||||
using SMPLComputeDataType = typename Traits_::SMPLComputeDataType;
|
||||
using PDataType = typename Traits_::PDataType;
|
||||
using OaccDataType = typename Traits_::OaccDataType;
|
||||
if(stream_config.log_level_ > 0)
|
||||
std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << ","
|
||||
<< Traits_::kHeadDim << ">" << std::flush;
|
||||
|
||||
index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock);
|
||||
|
||||
if(stream_config.log_level_ > 0)
|
||||
std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(stream_config,
|
||||
ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
|
||||
ck_tile::FlashAttentionFwd<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType,
|
||||
Traits_::kBlockSize,
|
||||
Traits_::kHeadDim,
|
||||
Traits_::kM0PerBlock,
|
||||
Traits_::kN0PerBlock,
|
||||
Traits_::kK0PerBlock,
|
||||
Traits_::kN1PerBlock,
|
||||
Traits_::kK1PerBlock>{{}},
|
||||
kGridSize,
|
||||
Traits_::kBlockSize,
|
||||
0,
|
||||
a.q_ptr,
|
||||
a.k_ptr,
|
||||
a.v_ptr,
|
||||
a.o_ptr,
|
||||
a.M0,
|
||||
a.N0,
|
||||
a.K0,
|
||||
a.N1,
|
||||
a.Batch,
|
||||
a.strideQ, // StrideQ
|
||||
a.strideK, // StrideK
|
||||
a.strideV, // StrideV
|
||||
a.strideO, // StrideO
|
||||
a.batchStrideQ, // BatchStrideQ
|
||||
a.batchStrideK, // BatchStrideK
|
||||
a.batchStrideV, // BatchStrideV
|
||||
a.batchStrideO)); // BatchStrideO
|
||||
}}
|
||||
}}
|
||||
return ck_tile::launch_kernel(
|
||||
stream_config,
|
||||
ck_tile::make_kernel<Traits_::kBlockSize, Traits_::kBlockPerCu>(
|
||||
ck_tile::FlashAttentionFwd<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType,
|
||||
Traits_::kBlockSize,
|
||||
Traits_::kHeadDim,
|
||||
Traits_::kM0PerBlock,
|
||||
Traits_::kN0PerBlock,
|
||||
Traits_::kK0PerBlock,
|
||||
Traits_::kN1PerBlock,
|
||||
Traits_::kK1PerBlock>{{}},
|
||||
kGridSize,
|
||||
Traits_::kBlockSize,
|
||||
0,
|
||||
a.q_ptr,
|
||||
a.k_ptr,
|
||||
a.v_ptr,
|
||||
a.o_ptr,
|
||||
a.M0,
|
||||
a.N0,
|
||||
a.K0,
|
||||
a.N1,
|
||||
a.Batch,
|
||||
a.strideQ, // StrideQ
|
||||
a.strideK, // StrideK
|
||||
a.strideV, // StrideV
|
||||
a.strideO, // StrideO
|
||||
a.batchStrideQ, // BatchStrideQ
|
||||
a.batchStrideK, // BatchStrideK
|
||||
a.batchStrideV, // BatchStrideV
|
||||
a.batchStrideO)); // BatchStrideO
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
def content_api(self, args) -> str:
|
||||
# Sort based on dtype
|
||||
#Sort based on dtype
|
||||
t_dtype_dict = {}
|
||||
blobs = self.get_blobs(args)
|
||||
|
||||
@@ -402,7 +415,7 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
h_traits = self.h_traits
|
||||
h_instance = self.h_instance
|
||||
|
||||
# Define kernel configurations for different size categories
|
||||
#Define kernel configurations for different size categories
|
||||
trait_dict = {
|
||||
"head_dim_256_seq_4096": [
|
||||
h_traits('fp32', 'fp32', 'fp32', 'fp32', 256, 256, 128, 128, 64, 128, 64),
|
||||
@@ -424,17 +437,17 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
],
|
||||
}
|
||||
|
||||
# Toy example only support fp16
|
||||
#Toy example only support fp16
|
||||
dtype_combinations = [
|
||||
"fp16,fp16,fp16,fp16"
|
||||
# "bf16,bf16,bf16,bf16"
|
||||
#"bf16,bf16,bf16,bf16"
|
||||
]
|
||||
|
||||
total_blob = []
|
||||
for dtype_pair in dtype_combinations:
|
||||
for size_category in trait_dict:
|
||||
traits = trait_dict[size_category]
|
||||
# Convert data types for the current dtype_pair
|
||||
#Convert data types for the current dtype_pair
|
||||
q_type, k_type, v_type, o_type = dtype_pair.split(',')
|
||||
current_traits = []
|
||||
for t in traits:
|
||||
@@ -455,10 +468,10 @@ float flash_attention_fwd_(const FlashAttnArgs<QDataType, KDataType, VDataType,
|
||||
blobs = self.get_blobs(args)
|
||||
|
||||
with list_p.open('w') as list_f:
|
||||
# API related files
|
||||
#API related files
|
||||
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
|
||||
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
|
||||
# Kernel instance files
|
||||
#Kernel instance files
|
||||
for b in blobs:
|
||||
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user