mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
chore(gemm): clang format to pass CI (#2758)
This commit is contained in:
@@ -29,7 +29,6 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
@@ -59,7 +58,7 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
|
||||
BLayout,
|
||||
CLayout>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
@@ -74,7 +73,6 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
constexpr bool transposed_warp_gemm = false;
|
||||
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
@@ -144,7 +142,8 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);;
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
;
|
||||
}
|
||||
|
||||
#include "run_gemm_bquant_example.inc"
|
||||
|
||||
@@ -83,10 +83,10 @@ struct GemmConfigBase
|
||||
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -104,7 +104,7 @@ struct GemmConfigDecode : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||
@@ -146,8 +146,7 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -96,13 +96,13 @@ struct BQuantGemmKernelArgs
|
||||
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
|
||||
struct BQuantGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename GemmPipeline::BQLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
|
||||
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
|
||||
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
|
||||
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
|
||||
using BQLayout = remove_cvref_t<typename GemmPipeline::BQLayout>;
|
||||
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
|
||||
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
|
||||
|
||||
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
|
||||
|
||||
@@ -44,12 +44,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
|
||||
using TileEncodingPattern = TileDistributionEncodingPatternBQ<BlockGemmShape,
|
||||
@@ -72,12 +72,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
|
||||
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
|
||||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
|
||||
static_assert(std::is_same_v<typename Problem::CDataType, float>);
|
||||
|
||||
@@ -32,7 +32,7 @@ struct TileGemmAQuantTraits
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
|
||||
Reference in New Issue
Block a user