chore(gemm): clang format to pass CI (#2758)

This commit is contained in:
Aviral Goel
2025-08-29 03:38:46 -04:00
committed by GitHub
parent 4208e28988
commit fcff0043ae
6 changed files with 30 additions and 32 deletions

View File

@@ -29,7 +29,6 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
constexpr bool kPadN = false;
constexpr bool kPadK = false;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
@@ -59,7 +58,7 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
BLayout,
CLayout>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
@@ -74,7 +73,6 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
@@ -144,7 +142,8 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);;
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
;
}
#include "run_gemm_bquant_example.inc"

View File

@@ -83,10 +83,10 @@ struct GemmConfigBase
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = false;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = false;
};
template <typename PrecType>
@@ -104,7 +104,7 @@ struct GemmConfigDecode : public GemmConfigBase
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 1;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
@@ -146,8 +146,7 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 1;
static constexpr int kBlockPerCu = 1;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;

View File

@@ -13,8 +13,8 @@
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -96,13 +96,13 @@ struct BQuantGemmKernelArgs
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct BQuantGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using BQLayout = remove_cvref_t<typename GemmPipeline::BQLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
using GemmPipeline = remove_cvref_t<GemmPipeline_>;
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
using BQLayout = remove_cvref_t<typename GemmPipeline::BQLayout>;
using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;
static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;

View File

@@ -44,12 +44,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
constexpr index_t VecLoadSize = GetVectorSizeBQ<Problem>();
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using TileEncodingPattern = TileDistributionEncodingPatternBQ<BlockGemmShape,
@@ -72,12 +72,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC
"KPerWarpGemm must be a multiple of kQuantGroupSize!");
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);

View File

@@ -32,7 +32,7 @@ struct TileGemmAQuantTraits
static constexpr bool UseStructuredSparsity = false;
static constexpr index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
static constexpr bool PreshuffleQuant = PreshuffleQuant_;
};
template <bool kPadM_,