diff --git a/example/ck_tile/XX_moe_gemm/moe_gemm.hpp b/example/ck_tile/XX_moe_gemm/moe_gemm.hpp index 1f2355f756..2d66aa306a 100644 --- a/example/ck_tile/XX_moe_gemm/moe_gemm.hpp +++ b/example/ck_tile/XX_moe_gemm/moe_gemm.hpp @@ -3,6 +3,7 @@ #pragma once +#include #include #include "ck_tile/core.hpp" @@ -31,6 +32,23 @@ using CDataType = Types::CDataType; using moe_gemm_kargs = ck_tile::MoeGemmHostArgs; +template +struct MoeGemmHostTraits +{ + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; + static constexpr ck_tile::index_t activation = activation_; + static constexpr bool IsGateOnly = gate_only_; + static constexpr bool IsFusedQuant = fused_quant_; +}; + + auto create_args(int argc, char* argv[]) { ck_tile::ArgParser arg_parser; @@ -48,6 +66,10 @@ auto create_args(int argc, char* argv[]) .insert("a_layout", "R", "A tensor data layout - Row by default.") .insert("b_layout", "C", "B tensor data layout - Col by default.") .insert("c_layout", "R", "C tensor data layout - Row by default.") + .insert("act", "0", "activation after first gemm. 0:gelu, 1:silu") + .insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant") + .insert( + "gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate") .insert("validate", "1", "0. No validation, 1. Validation on CPU.") .insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8") .insert("repeat", "10", "number of iterations to benchmark the kernel."); diff --git a/example/ck_tile/XX_moe_gemm/moe_gemm1_xdl_fp8.cpp b/example/ck_tile/XX_moe_gemm/moe_gemm1_xdl_fp8.cpp index e34aa6159f..bfb5545a34 100644 --- a/example/ck_tile/XX_moe_gemm/moe_gemm1_xdl_fp8.cpp +++ b/example/ck_tile/XX_moe_gemm/moe_gemm1_xdl_fp8.cpp @@ -38,7 +38,7 @@ struct MoeGemmKernelParam static const ck_tile::index_t K_Warp_Tile = 16; }; -template +template float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s) { using CodegenMoeGemmShape = ck_tile::TileFlatmmShape< @@ -54,19 +54,33 @@ float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s) using TilePartitioner = ck_tile::GemmTile1DPartitioner; - using CodegenMoeGemmTraits = ck_tile::TileGemmTraits; + true, + Traits::IsGateOnly, + Traits::IsFusedQuant, + typename Traits::ALayout, + typename Traits::BLayout, + typename Traits::CLayout, + decltype(get_activation_())>; using CodegenPipelineProblem = ck_tile::GemmPipelineProblem; + CodegenMoeGemmTraits, + AccDataType>; using CodegenMoeGemmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy; using CodegenMoeGemmPipeline = @@ -77,7 +91,7 @@ float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s) BDataType, AccDataType, CDataType, - CLayout, + typename Traits::CLayout, CodegenPipelineProblem::kBlockSize, TilePartitioner::MPerBlock, TilePartitioner::NPerBlock, diff --git a/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc b/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc index 6de0ae50be..a10d81020b 100644 --- a/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc +++ b/example/ck_tile/XX_moe_gemm/run_moe_gemm_example.inc @@ -70,10 +70,10 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_moe_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args) { - float ave_time = moe_gemm( + float ave_time = moe_gemm( args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::string op_name{"Moe Gemm"}; @@ -94,12 +94,9 @@ float invoke_moe_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args) return ave_time; } -template +template int run_moe_gemm_example_with_layouts(int argc, - char* argv[], - const ALayout a_layout = ALayout{}, - const BLayout b_layout = BLayout{}, - [[maybe_unused]] const CLayout c_layout = CLayout{}) + char* argv[]) { auto [result, arg_parser] = create_args(argc, argv); @@ -129,6 +126,11 @@ int run_moe_gemm_example_with_layouts(int argc, const ck_tile::index_t topk = arg_parser.get_int("TopK"); const ck_tile::index_t repeat = arg_parser.get_int("repeat"); const ck_tile::index_t experts = arg_parser.get_int("experts"); + const std::string mfma = arg_parser.get_str("prec"); + + auto a_layout = typename Traits::ALayout{}; + auto b_layout = typename Traits::BLayout{}; + auto c_layout = typename Traits::CLayout{}; // TODO: replace the magic declaration const ck_tile::index_t MPerBlock = 128; @@ -153,7 +155,7 @@ int run_moe_gemm_example_with_layouts(int argc, stride_A = ck_tile::get_default_stride(num_tokens, K, stride_A, is_row_major(a_layout)); stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); - stride_C = ck_tile::get_default_stride(num_tokens * topk, N, stride_C, is_row_major(CLayout{})); + stride_C = ck_tile::get_default_stride(num_tokens * topk, N, stride_C, is_row_major(c_layout)); auto a_m_k_tensor = ck_tile::HostTensor( ck_tile::host_tensor_descriptor(num_tokens, K, stride_A, is_row_major(a_layout))); @@ -164,10 +166,9 @@ int run_moe_gemm_example_with_layouts(int argc, ? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout)) : ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout))); - std::string mfma = arg_parser.get_str("prec"); auto c_m_n_tensor = ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(num_tokens * topk, N, stride_C, is_row_major(CLayout{}))); + ck_tile::host_tensor_descriptor(num_tokens * topk, N, stride_C, is_row_major(c_layout))); ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensor); ck_tile::FillUniformDistribution{-1.f, 1.f}(b_k_n_tensor); @@ -268,7 +269,7 @@ int run_moe_gemm_example_with_layouts(int argc, stride_B, stride_C}; - invoke_moe_gemm(3, repeat, gemm_desc); + invoke_moe_gemm(3, repeat, gemm_desc); c_m_n_dev_buf->FromDevice(c_m_n_tensor.data()); @@ -276,7 +277,7 @@ int run_moe_gemm_example_with_layouts(int argc, if(arg_parser.get_int("validate")) { ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( - num_tokens * topk, N, stride_C, is_row_major(CLayout{}))); + num_tokens * topk, N, stride_C, is_row_major(c_layout))); c_m_n_host_ref.SetZero(); @@ -289,9 +290,9 @@ int run_moe_gemm_example_with_layouts(int argc, BDataType, AccDataType, CDataType, - ALayout, - BLayout, - CLayout>( + typename Traits::ALayout, + typename Traits::BLayout, + typename Traits::CLayout>( p_sorted_token_ids_dev, p_expert_ids_dev, p_max_token_id_dev, @@ -353,15 +354,20 @@ int run_moe_gemm_example(int argc, char* argv[]) return -1; } - const std::string a_layout = arg_parser.get_str("a_layout"); - const std::string b_layout = arg_parser.get_str("b_layout"); + const std::string a_layout = arg_parser.get_str("a_layout"); + const std::string b_layout = arg_parser.get_str("b_layout"); + // const ck_tile::index_t act = arg_parser.get_int("act"); + // const ck_tile::index_t gate_only = arg_parser.get_int("gate_only"); + // const ck_tile::index_t fused_quant = arg_parser.get_int("fquant"); using Row = ck_tile::tensor_layout::gemm::RowMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor; if(a_layout == "R" && b_layout == "C") { - return run_moe_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{}); + using Traits = MoeGemmHostTraits; + + return run_moe_gemm_example_with_layouts(argc, argv); } // else if(a_layout == "R" && b_layout == "R") // { diff --git a/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp b/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp index 9f34db87f1..b728b65427 100644 --- a/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp +++ b/include/ck_tile/host/reference/reference_fused_single_moe_gemm.hpp @@ -11,66 +11,6 @@ namespace ck_tile { -// template -// CK_TILE_HOST void reference_gemm(const HostTensor& a_m_k, -// const HostTensor& b_k_n, -// HostTensor& c_m_n, -// const AElementOp& a_element_op = {}, -// const BElementOp& b_element_op = {}, -// const ACCElementOp& acc_element_op = {}) -// { -// const std::size_t M = a_m_k.get_length(0); -// const std::size_t N = b_k_n.get_length(1); -// const std::size_t K = a_m_k.get_length(1); - -// auto f_mn = [&](auto m, auto n) { -// AccDataType v_acc = 0; - -// for(std::size_t k = 0; k < K; ++k) -// { -// AccDataType v_a; -// AccDataType v_b; -// if constexpr(std::is_same_v) -// { -// const pk_int4_t pk_val = a_element_op(a_m_k(m, k)); -// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); -// if(k % 2 == 1) -// v_a = fp32_val.hi; -// else -// v_a = fp32_val.lo; -// } -// else -// { -// v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); -// } -// if constexpr(std::is_same_v) -// { -// const pk_int4_t pk_val = b_element_op(b_k_n(k, n)); -// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val); -// if(k % 2 == 1) -// v_b = fp32_val.hi; -// else -// v_b = fp32_val.lo; -// } -// else -// { -// v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); -// } -// v_acc += v_a * v_b; -// } - -// c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); -// }; - -// make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); -// } - template + bool IsInputGemm = true, + bool IsGateOnly = true, + index_t GateActivation = 0> __global__ void naive_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_, const ck_tile::index_t* p_sorted_expert_ids_, const ck_tile::index_t* p_max_token_id_, @@ -192,7 +134,9 @@ template + bool IsInputGemm = true, + bool IsGateOnly = true, + index_t GateActivation = 0> void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_, const index_t* p_sorted_expert_ids_, const index_t* p_max_token_id_, diff --git a/include/ck_tile/ops/moe_gemm.hpp b/include/ck_tile/ops/moe_gemm.hpp index 7819cb412e..2cac0d8f44 100644 --- a/include/ck_tile/ops/moe_gemm.hpp +++ b/include/ck_tile/ops/moe_gemm.hpp @@ -5,11 +5,11 @@ #include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" -#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp" #include "ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp" #include "ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp" #include "ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm_policy.hpp" +#include "ck_tile/ops/moe_gemm/pipeline/tile_moe_gemm_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp b/include/ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp index 622b323839..9652b04348 100644 --- a/include/ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp +++ b/include/ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp @@ -53,8 +53,7 @@ struct MoeGemmHostArgs : public ck_tile::FlatmmHostArgs template + typename EpiloguePipeline_> struct MoeGemmKernel { using TilePartitioner = remove_cvref_t; @@ -66,7 +65,7 @@ struct MoeGemmKernel using BlockGemmShape = remove_cvref_t; // TileFlatmmShape - static constexpr bool IsInputGemm = IsInputGemm_; + static constexpr bool IsInputGemm = FlatmmPipeline::IsInputGemm; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -635,7 +634,8 @@ struct MoeGemmKernel // Run Epilogue Pipeline auto& c_block_window = gemm_tile_windows.at(number<2>{}); - EpiloguePipeline{}.template operator()( + EpiloguePipeline{}.template operator()( c_block_window, c_block_tile, smem_ptr_0, diff --git a/include/ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp b/include/ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp index 6b807862e0..80b8bd8074 100644 --- a/include/ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp +++ b/include/ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp @@ -29,11 +29,17 @@ struct MoeGemmPipelineAgBgCrImpl using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; + using GateActivation = remove_cvref_t; + using BlockFlatmm = remove_cvref_t())>; using I0 = number<0>; using I1 = number<1>; using I2 = number<2>; + static constexpr bool IsInputGemm = Problem::Traits::IsInputGemm; + static constexpr bool IsGateOnly = Problem::Traits::IsGateOnly; + static constexpr bool IsFusedQuant = Problem::Traits::IsFusedQuant; + static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t kMPerBlock = BlockGemmShape::kM; @@ -113,32 +119,6 @@ struct MoeGemmPipelineAgBgCrImpl auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); - // auto a_dist = PipelinePolicy::template MakeADramTileDistribution(); - // auto a_coord = a_dist.calculate_index(); - // using ADstrEncode = typename decltype(a_dist)::DstrEncode; - // constexpr ck_tile::index_t MRepeat = ADstrEncode::hs_lengthss_[I0][I0]; - // statically_indexed_array a_offsets; - // static_for<0, MRepeat, 1>{}([&](auto n0) { - // int32_t seqlen_k_idx_per_repeat = cur_seqlen_k_idx + k_coord[0] + Traits::kBlockN / NRepeat * n0.value; - // int32_t page_idx = seqlen_k_idx_per_repeat / page_block_size; - // int32_t seq_idx = seqlen_k_idx_per_repeat % page_block_size; - // k_offsets[n0] = (block_indices[page_idx] * page_block_size + seq_idx) * stride_s_k; - // }); - // - // // A DRAM tile window for load - // auto a_dram_tile = ck_tile::make_tile_scatter_gather( - // a_dram_block_window_tmp.get_bottom_tensor_view(), - // a_dram_block_window_tmp.get_window_lengths(), - // a_dram_block_window_tmp.get_window_origin(), - // a_dist, - // k_offsets); // K DRAM tile window for - - // auto a_copy_dram_window = - // make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - // make_tuple(number{}, number{}), - // a_dram_block_window_tmp.get_window_origin(), - // PipelinePolicy::template MakeADramTileDistribution()); - // A LDS tile window for store auto a_copy_lds_window = make_tile_window( a_lds_block, make_tuple(number{}, number{}), {0, 0}); @@ -223,6 +203,15 @@ struct MoeGemmPipelineAgBgCrImpl block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window); } + sweep_tile(c_block_tile, + [&](auto idx0, auto idx1) { + fp32x2_t v_{c_block_tile(idx0), c_block_tile(idx1)}; + GateActivation{}(v_, v_); + c_block_tile(idx0) = v_.x; + c_block_tile(idx1) = v_.y; + }, + sequence<1, 2>{}); + return c_block_tile; } @@ -240,6 +229,163 @@ struct MoeGemmPipelineAgBgCrImpl p_smem); } + template + CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window, + const AElementFunction& a_element_func, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t N, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindow{}.get_window_lengths()[number<0>{}], + "wrong!"); + static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = + PipelinePolicy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + // A LDS tile window for store + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_flatmm = BlockFlatmm(); + + // B flat DRAM window for load + auto b_flat_distribution = + PipelinePolicy::template MakeBFlatDramTileDistribution(); + + auto b_gate_flat_dram_window = + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + b_flat_dram_block_window_tmp.move({N, 0}) + auto b_up_flat_dram_window = + make_tile_window( + b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views + make_tuple(number{}, number{}), + b_flat_dram_block_window_tmp.get_window_origin(), + b_flat_distribution); + + using c_block_tile_type = decltype(block_flatmm(a_lds_gemm_window, b_gate_flat_dram_window)); + auto c_block_tiles[2] = {c_block_tile_type{}, c_block_tile_type{}}; + + // prefetch + // global read 0 + auto a_block_tile = a_dram_block_window.load(); + + { + // move to 1 + move_tile_window(a_dram_block_window, {0, kKPerBlock}); + + // initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[0]); + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[1]); + + // LDS write 0 + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + PipelinePolicy::template MakeShuffledARegBlockDistribution()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } + } + + index_t iCounter = num_loop - 1; + while(iCounter > 0) + { + // global read i + 1 + a_dram_block_window.load(a_block_tile); + + block_sync_lds(); + + // GEMM i + block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window); + + //TODO: simply add b_gate flatmm + block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window); + + block_sync_lds(); + + // move to i + 2 + move_tile_window(a_dram_block_window, {0, kKPerBlock}); + + // LDS write i + 1 + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); + store_tile(a_copy_lds_window, a_block_tile_tmp); + + // move to next flat K + move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); + + iCounter--; + } + + // tail + { + block_sync_lds(); + + // GEMM num_loop - 1 + block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window); + block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window); + } + + sweep_tile(c_block_tiles[0], + [&](auto idx0, auto idx1) { + fp32x2_t v_{c_block_tiles[0].at(number<0>{})(idx0), c_block_tiles[0].at(number<0>{})(idx1)}; + typename Problem::GateActivation{}(v_, v_); + c_block_tiles[0].at(number<0>{})(idx0) = v_.x; + c_block_tiles[0].at(number<0>{})(idx1) = v_.y; + }, + sequence<1, 2>{}); + + auto c_block_tile = + tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; }, + c_block_tiles[0], + c_block_tiles[1]); + + return c_block_tiles[0]; + } + + template + CK_TILE_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window_tmp, + const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, + index_t N, + index_t num_loop, + void* p_smem) const + { + return operator()( + a_dram_block_window_tmp, + [](const ADataType& a) { return a; }, + b_flat_dram_block_window_tmp, + N, + num_loop, + p_smem); + } + }; } // namespace ck_tile diff --git a/include/ck_tile/ops/moe_gemm/pipeline/tile_moe_gemm_traits.hpp b/include/ck_tile/ops/moe_gemm/pipeline/tile_moe_gemm_traits.hpp new file mode 100644 index 0000000000..c705a3154d --- /dev/null +++ b/include/ck_tile/ops/moe_gemm/pipeline/tile_moe_gemm_traits.hpp @@ -0,0 +1,44 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct TileMoeGemmTraits +{ + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + static constexpr bool IsInputGemm = IsInputGemm_; + static constexpr bool IsGateOnly = IsGateOnly_; + static constexpr bool IsFusedQuant = IsFusedQuant_; + + // TODO this can't be hardcoded here! Should be in policy! + static constexpr int _VectorSize = 16; + + using ALayout = ALayout_; + using BLayout = BLayout_; + using CLayout = CLayout_; + + using GateActivation = remove_cvref_t; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; +}; + + +} // namespace ck_tile