diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp index 991c4841e4..49e60bf86d 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_basic.cpp @@ -29,7 +29,6 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s constexpr bool kPadN = false; constexpr bool kPadK = false; - static_assert(std::is_same_v); constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; @@ -59,7 +58,7 @@ float gemm_calc_bquant(const ck_tile::BQuantGemmHostArgs& args, const ck_tile::s BLayout, CLayout>; - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase @@ -104,7 +104,7 @@ struct GemmConfigDecode : public GemmConfigBase static constexpr ck_tile::index_t N_Warp_Tile = 16; static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - static constexpr int kBlockPerCu = 1; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE; @@ -146,8 +146,7 @@ struct GemmConfigPreshuffleQuant : public GemmConfigBase static constexpr ck_tile::index_t K_Warp_Tile = get_k_from_preshuffled_warp_tile(); - - static constexpr int kBlockPerCu = 1; + static constexpr int kBlockPerCu = 1; static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT; diff --git a/include/ck_tile/ops/gemm_group_quant.hpp b/include/ck_tile/ops/gemm_group_quant.hpp index 752da6a616..92a53dd5ea 100644 --- a/include/ck_tile/ops/gemm_group_quant.hpp +++ b/include/ck_tile/ops/gemm_group_quant.hpp @@ -13,8 +13,8 @@ #include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp" #include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp" #include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_bquant_kernel.hpp b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_bquant_kernel.hpp index 08b0ec0c2c..24e69d2628 100644 --- a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_bquant_kernel.hpp +++ b/include/ck_tile/ops/gemm_group_quant/kernel/gemm_bquant_kernel.hpp @@ -96,13 +96,13 @@ struct BQuantGemmKernelArgs template struct BQuantGemmKernel { - using TilePartitioner = remove_cvref_t; - using GemmPipeline = remove_cvref_t; - using EpiloguePipeline = remove_cvref_t; - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using BQLayout = remove_cvref_t; - using CLayout = remove_cvref_t; + using TilePartitioner = remove_cvref_t; + using GemmPipeline = remove_cvref_t; + using EpiloguePipeline = remove_cvref_t; + using ALayout = remove_cvref_t; + using BLayout = remove_cvref_t; + using BQLayout = remove_cvref_t; + using CLayout = remove_cvref_t; static constexpr index_t kBlockSize = GemmPipeline::BlockSize; using ADataType = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp index a5ed83d24b..ff986d86fb 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp @@ -44,12 +44,12 @@ struct GemmBQuantPipelineAgBgCrDefaultPolicy : public UniversalGemmPipelineAgBgC constexpr index_t VecLoadSize = GetVectorSizeBQ(); using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpGemm = WarpGemmDispatcher; + typename Problem::ComputeDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; static_assert(std::is_same_v); using TileEncodingPattern = TileDistributionEncodingPatternBQ; + typename Problem::ComputeDataType, + typename Problem::CDataType, + WarpTile::at(I0), + WarpTile::at(I1), + WarpTile::at(I2), + Problem::TransposeC>; static_assert(std::is_same_v || std::is_same_v); static_assert(std::is_same_v); diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp index a156bb773d..05ce35ae59 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp @@ -32,7 +32,7 @@ struct TileGemmAQuantTraits static constexpr bool UseStructuredSparsity = false; static constexpr index_t NumWaveGroups = 1; - static constexpr bool PreshuffleQuant = PreshuffleQuant_; + static constexpr bool PreshuffleQuant = PreshuffleQuant_; }; template