mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
upgrade from clang-format-12 to clang-format-18 (#2568)
* upgrade to clang-format-18 * update to clang-format-18 in pre-commit-config
This commit is contained in:
@@ -284,8 +284,8 @@ struct CShuffleEpilogue
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
sequence<0, 1>,
|
||||
sequence<MPerIterationShuffle, NPerIterationShuffle>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
|
||||
@@ -336,8 +336,8 @@ struct CShuffleEpilogue
|
||||
|
||||
const auto c_ds_tiles = concat_tuple_of_reference(
|
||||
tie(c_out_tensor, c_out_tensor),
|
||||
generate_tie(
|
||||
[&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
|
||||
generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
|
||||
number<NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
|
||||
|
||||
|
||||
@@ -458,7 +458,7 @@ struct FlatmmPipelineAGmemBGmemCRegV1 : public BaseFlatmmPipelineAGmemBGmemCRegV
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](const ADataType & a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
|
||||
@@ -431,12 +431,12 @@ struct UniversalFlatmmPipelineAgBgCrPolicy
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockFlatmmPolicy = BlockFlatmmASmemBSmemCRegV1CustomPolicy<
|
||||
typename Problem::ADataType,
|
||||
|
||||
@@ -509,7 +509,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
|
||||
template <typename Problem, index_t IBuf = 0>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
|
||||
MakeKLdsStoreBlockDescriptor(number<IBuf> = number<0>{})
|
||||
{
|
||||
// K is always k-major, we use async-copy to load into LDS
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
@@ -60,8 +60,8 @@ struct TileFmhaShape
|
||||
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
|
||||
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;
|
||||
using VLayout = std::conditional_t<IsVLayoutRowMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor>;
|
||||
};
|
||||
|
||||
template <typename BlockTile_, // sequence<...
|
||||
|
||||
@@ -385,7 +385,7 @@ struct FusedMoeGemmKernel
|
||||
auto o_window = [&]() {
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
|
||||
auto o_view_ = make_naive_tensor_view<address_space_enum::global,
|
||||
memory_operation_enum::atomic_add>(
|
||||
memory_operation_enum::atomic_add>(
|
||||
o_ptr,
|
||||
make_tuple(kargs.num_tokens, kargs.hidden_size),
|
||||
make_tuple(kargs.stride_token, 1),
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
namespace ck_tile {
|
||||
|
||||
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
|
||||
static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))
|
||||
static_cast<uint32_t>(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24))
|
||||
|
||||
#ifndef MOE_SORTING_USE_EX_KERNEL
|
||||
#define MOE_SORTING_USE_EX_KERNEL 1
|
||||
|
||||
@@ -267,8 +267,7 @@ struct FusedMoeGemmPipeline_FlatmmEx
|
||||
statically_indexed_array<a_thread_type, 2> as;
|
||||
|
||||
auto gld_a = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& a_store_, auto i_access, PreNop = {})
|
||||
{
|
||||
auto& a_store_, auto i_access, PreNop = {}) {
|
||||
async_load_tile_raw(a_store_, a_win, i_access, PreNop{});
|
||||
};
|
||||
auto move_a = [&]() {
|
||||
@@ -278,43 +277,40 @@ struct FusedMoeGemmPipeline_FlatmmEx
|
||||
load_tile_raw(a_, win_, i_access);
|
||||
};
|
||||
|
||||
auto gld_g = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& g_, auto i_access, PreNop = {})
|
||||
{
|
||||
if constexpr(IsGateOnly)
|
||||
{
|
||||
// TODO: hack!
|
||||
if constexpr(i_access.value == 0)
|
||||
auto gld_g =
|
||||
[&]<typename PreNop = bool_constant<false>>(auto& g_, auto i_access, PreNop = {}) {
|
||||
if constexpr(IsGateOnly)
|
||||
{
|
||||
g_win.bottom_tensor_view_ = g_view;
|
||||
// TODO: hack!
|
||||
if constexpr(i_access.value == 0)
|
||||
{
|
||||
g_win.bottom_tensor_view_ = g_view;
|
||||
}
|
||||
else if constexpr(i_access.value == issues_g / 2)
|
||||
{
|
||||
g_win.bottom_tensor_view_ = u_view;
|
||||
}
|
||||
}
|
||||
else if constexpr(i_access.value == issues_g / 2)
|
||||
{
|
||||
g_win.bottom_tensor_view_ = u_view;
|
||||
}
|
||||
}
|
||||
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
|
||||
};
|
||||
load_tile_raw(g_, g_win, i_access, FALSE, PreNop{});
|
||||
};
|
||||
auto move_g = [&]() {
|
||||
move_tile_window(g_win, {number<0>{}, number<BlockShape::Block_Kr0>{}, number<0>{}});
|
||||
};
|
||||
statically_indexed_array<d_thread_type, 2> ds;
|
||||
|
||||
auto gld_d = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& d_, auto i_access, PreNop = {})
|
||||
{
|
||||
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
|
||||
};
|
||||
auto gld_d =
|
||||
[&]<typename PreNop = bool_constant<false>>(auto& d_, auto i_access, PreNop = {}) {
|
||||
load_tile_raw(d_, d_win, i_access, FALSE, PreNop{});
|
||||
};
|
||||
auto move_d = [&]() {
|
||||
// d move along gemm-n
|
||||
move_tile_window(d_win, {number<BlockShape::Block_N1>{}, number<0>{}});
|
||||
};
|
||||
|
||||
auto atomic_add_o = [&]<typename PreNop = bool_constant<false>>(
|
||||
auto& o_, auto i_access, PreNop = {})
|
||||
{
|
||||
update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
|
||||
};
|
||||
auto atomic_add_o =
|
||||
[&]<typename PreNop = bool_constant<false>>(auto& o_, auto i_access, PreNop = {}) {
|
||||
update_tile_raw(o_win, o_, i_access, TRUE, PreNop{});
|
||||
};
|
||||
|
||||
auto acc_0 = Policy::template MakeCBlockTile_Gemm0<Problem>();
|
||||
auto acc_1s = generate_tuple(
|
||||
|
||||
@@ -69,8 +69,8 @@ struct GemmTile2DPartitioner
|
||||
* @param blockIdy WGP's Y index.
|
||||
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
|
||||
*/
|
||||
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
CK_TILE_DEVICE static auto
|
||||
GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple<index_t, index_t>
|
||||
{
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
|
||||
const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
|
||||
@@ -137,8 +137,8 @@ struct GemmTile1DPartitioner
|
||||
* @param blockIdx WGP's index.
|
||||
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
|
||||
*/
|
||||
CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
CK_TILE_DEVICE static auto
|
||||
GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple<index_t, index_t>
|
||||
{
|
||||
const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
|
||||
|
||||
@@ -188,9 +188,8 @@ struct OffsettedTile1DPartitioner
|
||||
* @param [in] N Gemm's N dimension.
|
||||
* @return Returns a `tuple` [Im, In] with shifted index.
|
||||
*/
|
||||
[[nodiscard]] CK_TILE_DEVICE static auto
|
||||
GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
[[nodiscard]] CK_TILE_DEVICE static auto GetOffsetedTileIndex(
|
||||
index_t block_start, index_t M, index_t N) noexcept -> const tuple<index_t, index_t>
|
||||
{
|
||||
const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
|
||||
return make_tuple(iM, iN);
|
||||
@@ -271,8 +270,8 @@ struct GemmSpatiallyLocalTilePartitioner
|
||||
* @param [in] block_1d_id WGP's index.
|
||||
* @return const tuple<index_t, index_t> Tuple containing 2D output C-tile index.
|
||||
*/
|
||||
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
|
||||
-> const tuple<index_t, index_t>
|
||||
CK_TILE_DEVICE auto
|
||||
GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple<index_t, index_t>
|
||||
{
|
||||
const auto M0 = integer_divide_ceil(M, MPerBlock);
|
||||
const auto N0 = integer_divide_ceil(N, NPerBlock);
|
||||
|
||||
@@ -144,8 +144,8 @@ struct GroupedGemmKernel
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
-> std::size_t
|
||||
CK_TILE_HOST static auto
|
||||
GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::size_t
|
||||
{
|
||||
return gemm_descs.size() * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
@@ -185,8 +185,8 @@ struct GroupedGemmKernel
|
||||
return dim3(grid_size, 1, 1);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static auto MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs)
|
||||
-> std::vector<GemmTransKernelArg>
|
||||
CK_TILE_HOST static auto
|
||||
MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::vector<GemmTransKernelArg>
|
||||
{
|
||||
std::vector<GemmTransKernelArg> gemm_kernel_args_;
|
||||
index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
|
||||
|
||||
@@ -28,20 +28,20 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
(DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType)) ==
|
||||
(WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size());
|
||||
constexpr auto wg_attr_num_access =
|
||||
((is_a_load_tr<Problem> || is_b_load_tr<Problem>)&&!single_load_tr_length)
|
||||
((is_a_load_tr<Problem> || is_b_load_tr<Problem>) && !single_load_tr_length)
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -24,12 +24,12 @@ struct GemmPipelineAgBgCrCompV5DefaultPolicy
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -284,9 +284,9 @@ struct GemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](const ADataType & a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](const BDataType & b) { return b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
@@ -394,12 +394,12 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
typename Problem::ComputeDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -241,9 +241,9 @@ struct GemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](const ADataType & a) { return a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](const BDataType& b) { return b; },
|
||||
[](const BDataType & b) { return b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
@@ -169,10 +169,10 @@ struct UniversalGemmBasePolicy
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
BTileAccessPattern>;
|
||||
|
||||
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
|
||||
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
|
||||
@@ -636,15 +636,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
|
||||
@@ -462,7 +462,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
[](const ADataType & a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
|
||||
@@ -430,12 +430,12 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockWeightPreshufflePolicy =
|
||||
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
|
||||
@@ -142,22 +142,15 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase<Problem_>
|
||||
// 2. bf8, fp32, bf8 -> f32
|
||||
// 3. i4, (fp8/fp32) fp8 -> f32
|
||||
// 4. i4, (fp8/fp32) bf8 -> f32
|
||||
static_assert(
|
||||
(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
|
||||
std::is_same_v<
|
||||
ADataType,
|
||||
bf8_t>)&&(std::is_same_v<BDataType, fp8_t> ||
|
||||
std::is_same_v<
|
||||
BDataType,
|
||||
bf8_t>)&&(std::is_same_v<AQDataType, float> ||
|
||||
std::is_same_v<AQDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<
|
||||
AQDataType,
|
||||
ck_tile::bf8_t>)&&(std::is_same_v<ComputeDataType,
|
||||
fp8_t> ||
|
||||
std::is_same_v<ComputeDataType,
|
||||
bf8_t>)&&std::
|
||||
is_same_v<CDataType, fp32_t>);
|
||||
static_assert((std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
|
||||
std::is_same_v<ADataType, bf8_t>) &&
|
||||
(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>) &&
|
||||
(std::is_same_v<AQDataType, float> ||
|
||||
std::is_same_v<AQDataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
|
||||
(std::is_same_v<ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<ComputeDataType, bf8_t>) &&
|
||||
std::is_same_v<CDataType, fp32_t>);
|
||||
|
||||
static constexpr index_t InterWaveSchedulingMacClusters = 1;
|
||||
|
||||
|
||||
@@ -44,12 +44,12 @@ struct GemmAQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t VecLoadSize = GetVectorSizeAQ<Problem>();
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
false>;
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
false>;
|
||||
|
||||
static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternAQ<BlockGemmShape,
|
||||
|
||||
@@ -202,8 +202,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseAQuantGemmPipelineAgBgCrCompV
|
||||
<< "AQ vector size: " << GetVectorSizeAQ() << "\n"
|
||||
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
|
||||
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
|
||||
<< ", "
|
||||
<< "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n"
|
||||
<< ", " << "AQ buffer load inst: " << AQ_Buffer_Load_Inst_Num << "\n"
|
||||
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
|
||||
@@ -37,13 +37,13 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -106,15 +106,15 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -122,13 +122,13 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -182,17 +182,17 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -201,17 +201,17 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -254,8 +254,9 @@ struct GroupedConvBwdWeightKernelArgs
|
||||
GemmBatch = args.G_;
|
||||
}
|
||||
|
||||
using ABCGridDescs = remove_cvref_t<decltype(
|
||||
ConvToGemmTransformer{}.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
|
||||
using ABCGridDescs =
|
||||
remove_cvref_t<decltype(ConvToGemmTransformer{}
|
||||
.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())>;
|
||||
|
||||
using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
|
||||
using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
|
||||
|
||||
@@ -37,13 +37,13 @@ struct GroupedConvFwdKernelArgs
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -107,15 +107,15 @@ struct GroupedConvFwdKernelArgs
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -123,13 +123,13 @@ struct GroupedConvFwdKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[1])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
static_cast<index_t>(args.input_left_pads_[1])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
static_cast<index_t>(args.input_right_pads_[1])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -184,17 +184,17 @@ struct GroupedConvFwdKernelArgs
|
||||
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs& args)
|
||||
{
|
||||
in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.input_spatial_lengths_[2])};
|
||||
wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
static_cast<index_t>(args.K_),
|
||||
static_cast<index_t>(args.C_),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[0]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[1]),
|
||||
static_cast<index_t>(args.filter_spatial_lengths_[2])};
|
||||
out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
|
||||
static_cast<index_t>(args.N_),
|
||||
static_cast<index_t>(args.K_),
|
||||
@@ -203,17 +203,17 @@ struct GroupedConvFwdKernelArgs
|
||||
static_cast<index_t>(args.output_spatial_lengths_[2])};
|
||||
|
||||
conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
static_cast<index_t>(args.conv_filter_strides_[1]),
|
||||
static_cast<index_t>(args.conv_filter_strides_[2])};
|
||||
conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[1]),
|
||||
static_cast<index_t>(args.conv_filter_dilations_[2])};
|
||||
input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
static_cast<index_t>(args.input_left_pads_[1]),
|
||||
static_cast<index_t>(args.input_left_pads_[2])};
|
||||
input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
static_cast<index_t>(args.input_right_pads_[1]),
|
||||
static_cast<index_t>(args.input_right_pads_[2])};
|
||||
|
||||
k_batch = args.k_batch;
|
||||
|
||||
@@ -259,15 +259,15 @@ struct GroupedConvFwdKernelArgs
|
||||
group_stride_c = args.K_;
|
||||
}
|
||||
|
||||
using AGridDescMK = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>())>;
|
||||
using BGridDescNK = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>())>;
|
||||
using CGridDescMN = remove_cvref_t<decltype(
|
||||
ConvToGemmFwdTransformer{}
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>())>;
|
||||
using AGridDescMK = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>())>;
|
||||
using BGridDescNK = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>())>;
|
||||
using CGridDescMN = remove_cvref_t<
|
||||
decltype(ConvToGemmFwdTransformer{}
|
||||
.template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>())>;
|
||||
|
||||
static constexpr index_t NonSpatialDims = 3;
|
||||
array<index_t, NonSpatialDims + GroupedConvTraitsType::NDimSpatial> in_g_n_c_wis_lengths;
|
||||
|
||||
@@ -67,11 +67,11 @@ struct GroupedConvTraits
|
||||
using DsLayout = DsLayout_;
|
||||
using OutLayout = OutLayout_;
|
||||
using GroupedConvImplicitGemmTraits = TileGemmTraits<true,
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
true,
|
||||
true,
|
||||
ck_tile::tensor_layout::gemm::RowMajor,
|
||||
ck_tile::tensor_layout::gemm::ColumnMajor,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>;
|
||||
static constexpr index_t NumDTensor = DsLayout::size();
|
||||
using ImplicitGemmDsLayout = decltype(generate_implicit_gemm_layout());
|
||||
};
|
||||
|
||||
@@ -380,6 +380,6 @@ struct BlockReduce2D
|
||||
|
||||
// deduction guide
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
|
||||
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&) -> BlockReduce2D<T>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user