add activation and gate-up

This commit is contained in:
zanzhang
2025-04-28 09:52:56 +08:00
parent 36d0c9d8e2
commit e030c34211
8 changed files with 294 additions and 118 deletions

View File

@@ -3,6 +3,7 @@
#pragma once
#include <cwchar>
#include <string>
#include "ck_tile/core.hpp"
@@ -31,6 +32,23 @@ using CDataType = Types::CDataType;
using moe_gemm_kargs = ck_tile::MoeGemmHostArgs;
template <typename ALayout_,
typename BLayout_,
typename CLayout_,
ck_tile::index_t activation_ = 0,
bool gate_only_ = false,
bool fused_quant_ = false>
struct MoeGemmHostTraits
{
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
static constexpr ck_tile::index_t activation = activation_;
static constexpr bool IsGateOnly = gate_only_;
static constexpr bool IsFusedQuant = fused_quant_;
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
@@ -48,6 +66,10 @@ auto create_args(int argc, char* argv[])
.insert("a_layout", "R", "A tensor data layout - Row by default.")
.insert("b_layout", "C", "B tensor data layout - Col by default.")
.insert("c_layout", "R", "C tensor data layout - Row by default.")
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
.insert(
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("repeat", "10", "number of iterations to benchmark the kernel.");

View File

@@ -38,7 +38,7 @@ struct MoeGemmKernelParam
static const ck_tile::index_t K_Warp_Tile = 16;
};
template <typename ALayout, typename BLayout, typename CLayout>
template <typename Traits>
float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s)
{
using CodegenMoeGemmShape = ck_tile::TileFlatmmShape<
@@ -54,19 +54,33 @@ float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s)
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenMoeGemmShape>;
using CodegenMoeGemmTraits = ck_tile::TileGemmTraits<MoeGemmKernelParam::kPadM,
constexpr auto get_activation_ = []() {
if constexpr(Traits::activation == 0)
{
return ck_tile::element_wise::FastGeluAsm{};
}
else
return ck_tile::element_wise::Silu{};
};
using CodegenMoeGemmTraits = ck_tile::TileMoeGemmTraits<MoeGemmKernelParam::kPadM,
MoeGemmKernelParam::kPadN,
MoeGemmKernelParam::kPadK,
ALayout,
BLayout,
CLayout>;
true,
Traits::IsGateOnly,
Traits::IsFusedQuant,
typename Traits::ALayout,
typename Traits::BLayout,
typename Traits::CLayout,
decltype(get_activation_())>;
using CodegenPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType,
BDataType,
AccDataType,
CodegenMoeGemmShape,
CodegenMoeGemmTraits>;
CodegenMoeGemmTraits,
AccDataType>;
using CodegenMoeGemmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
using CodegenMoeGemmPipeline =
@@ -77,7 +91,7 @@ float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s)
BDataType,
AccDataType,
CDataType,
CLayout,
typename Traits::CLayout,
CodegenPipelineProblem::kBlockSize,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,

View File

@@ -70,10 +70,10 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename Traits>
float invoke_moe_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args)
{
float ave_time = moe_gemm<ALayout, BLayout, CLayout>(
float ave_time = moe_gemm<Traits>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::string op_name{"Moe Gemm"};
@@ -94,12 +94,9 @@ float invoke_moe_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args)
return ave_time;
}
template <typename ALayout, typename BLayout, typename CLayout>
template <typename Traits>
int run_moe_gemm_example_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
@@ -129,6 +126,11 @@ int run_moe_gemm_example_with_layouts(int argc,
const ck_tile::index_t topk = arg_parser.get_int("TopK");
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
const ck_tile::index_t experts = arg_parser.get_int("experts");
const std::string mfma = arg_parser.get_str("prec");
auto a_layout = typename Traits::ALayout{};
auto b_layout = typename Traits::BLayout{};
auto c_layout = typename Traits::CLayout{};
// TODO: replace the magic declaration
const ck_tile::index_t MPerBlock = 128;
@@ -153,7 +155,7 @@ int run_moe_gemm_example_with_layouts(int argc,
stride_A = ck_tile::get_default_stride(num_tokens, K, stride_A, is_row_major(a_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(num_tokens * topk, N, stride_C, is_row_major(CLayout{}));
stride_C = ck_tile::get_default_stride(num_tokens * topk, N, stride_C, is_row_major(c_layout));
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(
ck_tile::host_tensor_descriptor(num_tokens, K, stride_A, is_row_major(a_layout)));
@@ -164,10 +166,9 @@ int run_moe_gemm_example_with_layouts(int argc,
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
std::string mfma = arg_parser.get_str("prec");
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(
ck_tile::host_tensor_descriptor(num_tokens * topk, N, stride_C, is_row_major(CLayout{})));
ck_tile::host_tensor_descriptor(num_tokens * topk, N, stride_C, is_row_major(c_layout)));
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensor);
@@ -268,7 +269,7 @@ int run_moe_gemm_example_with_layouts(int argc,
stride_B,
stride_C};
invoke_moe_gemm<ALayout, BLayout, CLayout>(3, repeat, gemm_desc);
invoke_moe_gemm<Traits>(3, repeat, gemm_desc);
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
@@ -276,7 +277,7 @@ int run_moe_gemm_example_with_layouts(int argc,
if(arg_parser.get_int("validate"))
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
num_tokens * topk, N, stride_C, is_row_major(CLayout{})));
num_tokens * topk, N, stride_C, is_row_major(c_layout)));
c_m_n_host_ref.SetZero();
@@ -289,9 +290,9 @@ int run_moe_gemm_example_with_layouts(int argc,
BDataType,
AccDataType,
CDataType,
ALayout,
BLayout,
CLayout>(
typename Traits::ALayout,
typename Traits::BLayout,
typename Traits::CLayout>(
p_sorted_token_ids_dev,
p_expert_ids_dev,
p_max_token_id_dev,
@@ -353,15 +354,20 @@ int run_moe_gemm_example(int argc, char* argv[])
return -1;
}
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
const std::string a_layout = arg_parser.get_str("a_layout");
const std::string b_layout = arg_parser.get_str("b_layout");
// const ck_tile::index_t act = arg_parser.get_int("act");
// const ck_tile::index_t gate_only = arg_parser.get_int("gate_only");
// const ck_tile::index_t fused_quant = arg_parser.get_int("fquant");
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if(a_layout == "R" && b_layout == "C")
{
return run_moe_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
using Traits = MoeGemmHostTraits<Row, Col, Row, 0, false, false>;
return run_moe_gemm_example_with_layouts<Traits>(argc, argv);
}
// else if(a_layout == "R" && b_layout == "R")
// {

View File

@@ -11,66 +11,6 @@
namespace ck_tile {
// template <typename ADataType,
// typename BDataType,
// typename AccDataType,
// typename CDataType,
// typename AElementOp = ck_tile::identity,
// typename BElementOp = ck_tile::identity,
// typename ACCElementOp = ck_tile::identity>
// CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
// const HostTensor<BDataType>& b_k_n,
// HostTensor<CDataType>& c_m_n,
// const AElementOp& a_element_op = {},
// const BElementOp& b_element_op = {},
// const ACCElementOp& acc_element_op = {})
// {
// const std::size_t M = a_m_k.get_length(0);
// const std::size_t N = b_k_n.get_length(1);
// const std::size_t K = a_m_k.get_length(1);
// auto f_mn = [&](auto m, auto n) {
// AccDataType v_acc = 0;
// for(std::size_t k = 0; k < K; ++k)
// {
// AccDataType v_a;
// AccDataType v_b;
// if constexpr(std::is_same_v<ADataType, pk_int4_t>)
// {
// const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
// if(k % 2 == 1)
// v_a = fp32_val.hi;
// else
// v_a = fp32_val.lo;
// }
// else
// {
// v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
// }
// if constexpr(std::is_same_v<BDataType, pk_int4_t>)
// {
// const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
// if(k % 2 == 1)
// v_b = fp32_val.hi;
// else
// v_b = fp32_val.lo;
// }
// else
// {
// v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
// }
// v_acc += v_a * v_b;
// }
// c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
// };
// make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
// }
template <typename ADataType,
typename BDataType,
typename AccDataType,
@@ -78,7 +18,9 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
bool IsInputGemm = true>
bool IsInputGemm = true,
bool IsGateOnly = true,
index_t GateActivation = 0>
__global__ void naive_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
const ck_tile::index_t* p_sorted_expert_ids_,
const ck_tile::index_t* p_max_token_id_,
@@ -192,7 +134,9 @@ template <typename ADataType,
typename LayoutA,
typename LayoutB,
typename LayoutC,
bool IsInputGemm = true>
bool IsInputGemm = true,
bool IsGateOnly = true,
index_t GateActivation = 0>
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
const index_t* p_sorted_expert_ids_,
const index_t* p_max_token_id_,

View File

@@ -5,11 +5,11 @@
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp"
#include "ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp"
#include "ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm_policy.hpp"
#include "ck_tile/ops/moe_gemm/pipeline/tile_moe_gemm_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -53,8 +53,7 @@ struct MoeGemmHostArgs : public ck_tile::FlatmmHostArgs
template <typename TilePartitioner_,
typename FlatmmPipeline_,
typename EpiloguePipeline_,
bool IsInputGemm_ = true>
typename EpiloguePipeline_>
struct MoeGemmKernel
{
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
@@ -66,7 +65,7 @@ struct MoeGemmKernel
using BlockGemmShape =
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
static constexpr bool IsInputGemm = IsInputGemm_;
static constexpr bool IsInputGemm = FlatmmPipeline::IsInputGemm;
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
@@ -635,7 +634,8 @@ struct MoeGemmKernel
// Run Epilogue Pipeline
auto& c_block_window = gemm_tile_windows.at(number<2>{});
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
EpiloguePipeline{}.template operator()<decltype(c_block_window),
decltype(c_block_tile)>(
c_block_window,
c_block_tile,
smem_ptr_0,

View File

@@ -29,11 +29,17 @@ struct MoeGemmPipelineAgBgCrImpl
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using GateActivation = remove_cvref_t<typename Problem::Traits::GateActivation>;
using BlockFlatmm = remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr bool IsInputGemm = Problem::Traits::IsInputGemm;
static constexpr bool IsGateOnly = Problem::Traits::IsGateOnly;
static constexpr bool IsFusedQuant = Problem::Traits::IsFusedQuant;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
@@ -113,32 +119,6 @@ struct MoeGemmPipelineAgBgCrImpl
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// auto a_dist = PipelinePolicy::template MakeADramTileDistribution<Problem>();
// auto a_coord = a_dist.calculate_index();
// using ADstrEncode = typename decltype(a_dist)::DstrEncode;
// constexpr ck_tile::index_t MRepeat = ADstrEncode::hs_lengthss_[I0][I0];
// statically_indexed_array<ck_tile::index_t, NRepeat> a_offsets;
// static_for<0, MRepeat, 1>{}([&](auto n0) {
// int32_t seqlen_k_idx_per_repeat = cur_seqlen_k_idx + k_coord[0] + Traits::kBlockN / NRepeat * n0.value;
// int32_t page_idx = seqlen_k_idx_per_repeat / page_block_size;
// int32_t seq_idx = seqlen_k_idx_per_repeat % page_block_size;
// k_offsets[n0] = (block_indices[page_idx] * page_block_size + seq_idx) * stride_s_k;
// });
//
// // A DRAM tile window for load
// auto a_dram_tile = ck_tile::make_tile_scatter_gather(
// a_dram_block_window_tmp.get_bottom_tensor_view(),
// a_dram_block_window_tmp.get_window_lengths(),
// a_dram_block_window_tmp.get_window_origin(),
// a_dist,
// k_offsets); // K DRAM tile window for
// auto a_copy_dram_window =
// make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
// make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
// a_dram_block_window_tmp.get_window_origin(),
// PipelinePolicy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store
auto a_copy_lds_window = make_tile_window(
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
@@ -223,6 +203,15 @@ struct MoeGemmPipelineAgBgCrImpl
block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
}
sweep_tile(c_block_tile,
[&](auto idx0, auto idx1) {
fp32x2_t v_{c_block_tile(idx0), c_block_tile(idx1)};
GateActivation{}(v_, v_);
c_block_tile(idx0) = v_.x;
c_block_tile(idx1) = v_.y;
},
sequence<1, 2>{});
return c_block_tile;
}
@@ -240,6 +229,163 @@ struct MoeGemmPipelineAgBgCrImpl
p_smem);
}
/**
 * @brief Gate+up pipeline: runs two flat GEMMs against the same A tile and
 *        returns `GateActivation(A * B_gate) * (A * B_up)`.
 *
 * The A tile is staged through LDS (pointed to by @p p_smem) once per K step and
 * consumed by both GEMMs; the gate and up weights are read through two separate
 * flat B windows.
 *
 * @param a_dram_block_window          A window; loaded from and advanced along K in place.
 * @param a_element_func               Element-wise function applied to A before the LDS store.
 * @param b_flat_dram_block_window_tmp Flat B window positioned at the gate weights (not modified).
 * @param N                            Row offset from the gate weights to the up weights.
 * @param num_loop                     Number of K steps; `num_loop - 1` run in the main loop and
 *                                     the last one in the tail (assumes `num_loop >= 1` - TODO confirm).
 * @param p_smem                       Buffer used as the A tile in LDS.
 * @return Accumulator tile holding the activated gate result multiplied by the up result.
 */
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp, typename AElementFunction>
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window,
                                    const AElementFunction& a_element_func,
                                    const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
                                    index_t N,
                                    index_t num_loop,
                                    void* p_smem) const
{
    static_assert(
        std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindow::DataType>>,
        "wrong!");

    static_assert(kMPerBlock == ADramBlockWindow{}.get_window_lengths()[number<0>{}],
                  "wrong!");
    static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}],
                  "wrong!");

    // A tile in LDS
    ADataType* p_a_lds = static_cast<ADataType*>(p_smem);

    constexpr auto a_lds_block_desc =
        PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();

    auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);

    // A LDS tile window for store
    auto a_copy_lds_window = make_tile_window(
        a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

    // A LDS tile for block GEMM
    auto a_lds_gemm_window = make_tile_window(
        a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});

    // Block GEMM
    auto block_flatmm = BlockFlatmm();

    // B flat DRAM windows for load
    auto b_flat_distribution =
        PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();

    auto b_gate_flat_dram_window = make_tile_window(
        b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
        make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
        b_flat_dram_block_window_tmp.get_window_origin(),
        b_flat_distribution);

    // The incoming window is const, so the up window is created at the gate origin
    // and then shifted by N rows instead of mutating the argument.
    // NOTE(review): assumes N is already expressed in the flat-B row coordinate - confirm.
    auto b_up_flat_dram_window = make_tile_window(
        b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
        make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
        b_flat_dram_block_window_tmp.get_window_origin(),
        b_flat_distribution);
    move_tile_window(b_up_flat_dram_window, {N, 0});

    // One accumulator per projection (`auto` cannot declare an array, hence two objects).
    using c_block_tile_type =
        decltype(block_flatmm(a_lds_gemm_window, b_gate_flat_dram_window));
    auto c_gate_block_tile = c_block_tile_type{};
    auto c_up_block_tile   = c_block_tile_type{};

    // prefetch
    // global read 0
    auto a_block_tile = a_dram_block_window.load();

    {
        // move to 1
        move_tile_window(a_dram_block_window, {0, kKPerBlock});

        // initialize C
        tile_elementwise_inout([](auto& c) { c = 0; }, c_gate_block_tile);
        tile_elementwise_inout([](auto& c) { c = 0; }, c_up_block_tile);

        // LDS write 0
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
        {
            auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
                PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
            shuffle_tile(a_shuffle_tmp, a_block_tile);
            const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
            store_tile(a_copy_lds_window, a_block_tile_tmp);
        }
        else
        {
            store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
        }
    }

    index_t iCounter = num_loop - 1;
    while(iCounter > 0)
    {
        // global read i + 1
        a_dram_block_window.load(a_block_tile);

        block_sync_lds();

        // GEMM i: gate projection, then up projection, both reading the same A tile in LDS
        block_flatmm(c_gate_block_tile, a_lds_gemm_window, b_gate_flat_dram_window);
        // TODO: the up GEMM is simply issued after the gate GEMM; no interleaving yet
        block_flatmm(c_up_block_tile, a_lds_gemm_window, b_up_flat_dram_window);

        // both GEMMs must finish reading LDS before the next A tile overwrites it
        block_sync_lds();

        // move to i + 2
        move_tile_window(a_dram_block_window, {0, kKPerBlock});

        // LDS write i + 1
        const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
        store_tile(a_copy_lds_window, a_block_tile_tmp);

        // move to next flat K
        move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
        move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});

        iCounter--;
    }

    // tail
    {
        block_sync_lds();

        // GEMM num_loop - 1
        block_flatmm(c_gate_block_tile, a_lds_gemm_window, b_gate_flat_dram_window);
        block_flatmm(c_up_block_tile, a_lds_gemm_window, b_up_flat_dram_window);
    }

    // Apply the gate activation two accumulator elements at a time, mirroring the
    // single-GEMM overload of this pipeline.
    sweep_tile(
        c_gate_block_tile,
        [&](auto idx0, auto idx1) {
            fp32x2_t v_{c_gate_block_tile(idx0), c_gate_block_tile(idx1)};
            GateActivation{}(v_, v_);
            c_gate_block_tile(idx0) = v_.x;
            c_gate_block_tile(idx1) = v_.y;
        },
        sequence<1, 2>{});

    // act(gate) * up; this product (not the bare gate tile) is the pipeline result.
    auto c_block_tile =
        tile_elementwise_in([](const auto& a_, const auto& b_) { return a_ * b_; },
                            c_gate_block_tile,
                            c_up_block_tile);

    return c_block_tile;
}
/// Gate+up convenience overload: forwards to the element-function overload with
/// an identity function, so the A elements are stored to LDS unchanged.
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp>
CK_TILE_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window_tmp,
                               const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
                               index_t N,
                               index_t num_loop,
                               void* p_smem) const
{
    // Pass-through applied to every A element.
    const auto keep_as_is = [](const ADataType& a) { return a; };

    return operator()(a_dram_block_window_tmp,
                      keep_as_is,
                      b_flat_dram_block_window_tmp,
                      N,
                      num_loop,
                      p_smem);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/// Compile-time traits describing one MoE GEMM tile configuration.
///
/// Every template argument is re-exported under a member of the same meaning so
/// that the pipeline can query it through `Problem::Traits`.
///
/// @tparam PadM / PadN / PadK  Padding switches for the M, N and K dimensions.
/// @tparam InputGemm           Exposed as `IsInputGemm`.
/// @tparam GateOnly            Exposed as `IsGateOnly`.
/// @tparam FusedQuant          Exposed as `IsFusedQuant`.
/// @tparam LayoutA / LayoutB / LayoutC  Tensor layout tags.
/// @tparam Activation          Functor type exposed (cv/ref-stripped) as `GateActivation`.
template <bool PadM,
          bool PadN,
          bool PadK,
          bool InputGemm,
          bool GateOnly,
          bool FusedQuant,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          typename Activation>
struct TileMoeGemmTraits
{
    // Tensor layouts and the activation functor.
    using GateActivation = remove_cvref_t<Activation>;
    using CLayout        = LayoutC;
    using BLayout        = LayoutB;
    using ALayout        = LayoutA;

    // Padding switches.
    static constexpr bool kPadK = PadK;
    static constexpr bool kPadN = PadN;
    static constexpr bool kPadM = PadM;

    // MoE-specific switches.
    static constexpr bool IsFusedQuant = FusedQuant;
    static constexpr bool IsGateOnly   = GateOnly;
    static constexpr bool IsInputGemm  = InputGemm;

    // TODO: this must not stay hard-coded here; it belongs in the policy.
    static constexpr int _VectorSize = 16;

    // Fixed to false for this traits type.
    // NOTE(review): presumably required by the generic GEMM pipeline problem - confirm.
    static constexpr bool UseStructuredSparsity = false;
    static constexpr bool TransposeC            = false;
};
} // namespace ck_tile