bug fix, clang format;

This commit is contained in:
aska-0096
2025-08-08 09:04:02 +00:00
parent 3b9fb6af38
commit 78edd7303b
16 changed files with 180 additions and 80 deletions

View File

@@ -199,6 +199,7 @@
#else
#define CK_TILE_USE_LLVM_BUILTIN_BF16 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0

View File

@@ -1308,7 +1308,7 @@ struct Swish
struct SoftRelu
{
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
SoftRelu(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1327,7 +1327,7 @@ struct SoftRelu
struct Power
{
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
: alpha_(alpha), beta_(beta), gamma_(gamma) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1349,7 +1349,7 @@ struct Power
struct ClippedRelu
{
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1368,7 +1368,7 @@ struct ClippedRelu
struct LeakyRelu
{
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
LeakyRelu(float alpha = 0.01f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1385,7 +1385,7 @@ struct LeakyRelu
struct Elu
{
Elu(float alpha = 1.f) : alpha_(alpha){};
Elu(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
@@ -1402,7 +1402,7 @@ struct Elu
struct Logistic
{
Logistic(float alpha = 1.f) : alpha_(alpha){};
Logistic(float alpha = 1.f) : alpha_(alpha) {};
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const

View File

@@ -193,7 +193,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
@@ -512,7 +512,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
typename Problem::BlockFmhaShape::Gemm0WarpTile>,
GemmLoopOrder::MNK>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::QDataType,
@@ -545,7 +546,8 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
Problem::BlockFmhaShape::kN1,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
typename Problem::BlockFmhaShape::Gemm1WarpTile>,
GemmLoopOrder::KMN>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::PDataType,
@@ -597,7 +599,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(

View File

@@ -40,6 +40,8 @@ struct BlockGemmARegBRegCRegV1
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr auto BlockGemmLoopOrder = Problem::BlockGemmLoopOrder;
static constexpr index_t KPack = WarpGemm::kKPerThread;
};
@@ -49,8 +51,9 @@ struct BlockGemmARegBRegCRegV1
using Traits = GemmTraits_<Problem, Policy>;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
using WarpGemm = typename Traits::WarpGemm;
using BlockGemmShape = typename Traits::BlockGemmShape;
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
@@ -83,17 +86,36 @@ struct BlockGemmARegBRegCRegV1
}
else
{
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<2, 1>,
sequence<0, 0>>{};
return a_block_dstr_encode;
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode;
}
}
}
@@ -115,17 +137,33 @@ struct BlockGemmARegBRegCRegV1
}
else
{
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<2, 1>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
return b_block_dstr_encode;
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode;
}
}
}
@@ -215,41 +253,82 @@ struct BlockGemmARegBRegCRegV1
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
{
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
});
}
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
{
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
}
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()

View File

@@ -7,13 +7,20 @@
namespace ck_tile {
enum struct GemmLoopOrder
{
KMN = 0,
MNK = 1,
};
// Problem Description for BlockGemm
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
index_t kBlockSize_,
typename BlockGemmShape_,
index_t NumWaveGroups_ = 1>
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN,
index_t NumWaveGroups_ = 1>
struct BlockGemmProblem
{
using ADataType = remove_cvref_t<ADataType_>;
@@ -21,8 +28,9 @@ struct BlockGemmProblem
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t NumWaveGroups = NumWaveGroups_;
static constexpr index_t kBlockSize = kBlockSize_;
static constexpr index_t NumWaveGroups = NumWaveGroups_;
static constexpr GemmLoopOrder BlockGemmLoopOrder = BlockGemmLoopOrder_;
};
} // namespace ck_tile

View File

@@ -104,6 +104,10 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
1>>;
#endif
using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
@@ -210,6 +214,10 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
AttrNumAccess>>;
#endif
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
#if defined(__gfx950__)
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<