mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add activation and gate-up
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cwchar>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
@@ -31,6 +32,23 @@ using CDataType = Types::CDataType;
|
||||
|
||||
using moe_gemm_kargs = ck_tile::MoeGemmHostArgs;
|
||||
|
||||
template <typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
ck_tile::index_t activation_ = 0,
|
||||
bool gate_only_ = false,
|
||||
bool fused_quant_ = false>
|
||||
struct MoeGemmHostTraits
|
||||
{
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
static constexpr ck_tile::index_t activation = activation_;
|
||||
static constexpr bool IsGateOnly = gate_only_;
|
||||
static constexpr bool IsFusedQuant = fused_quant_;
|
||||
};
|
||||
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
@@ -48,6 +66,10 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default.")
|
||||
.insert("b_layout", "C", "B tensor data layout - Col by default.")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default.")
|
||||
.insert("act", "0", "activation after first gemm. 0:gelu, 1:silu")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert(
|
||||
"gate_only", "1", "w0(gate/up) style, 0:gate+up will double interm size, 1:only gate")
|
||||
.insert("validate", "1", "0. No validation, 1. Validation on CPU.")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("repeat", "10", "number of iterations to benchmark the kernel.");
|
||||
|
||||
@@ -38,7 +38,7 @@ struct MoeGemmKernelParam
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
};
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename Traits>
|
||||
float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s)
|
||||
{
|
||||
using CodegenMoeGemmShape = ck_tile::TileFlatmmShape<
|
||||
@@ -54,19 +54,33 @@ float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s)
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenMoeGemmShape>;
|
||||
|
||||
using CodegenMoeGemmTraits = ck_tile::TileGemmTraits<MoeGemmKernelParam::kPadM,
|
||||
constexpr auto get_activation_ = []() {
|
||||
if constexpr(Traits::activation == 0)
|
||||
{
|
||||
return ck_tile::element_wise::FastGeluAsm{};
|
||||
}
|
||||
else
|
||||
return ck_tile::element_wise::Silu{};
|
||||
};
|
||||
|
||||
using CodegenMoeGemmTraits = ck_tile::TileMoeGemmTraits<MoeGemmKernelParam::kPadM,
|
||||
MoeGemmKernelParam::kPadN,
|
||||
MoeGemmKernelParam::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
true,
|
||||
Traits::IsGateOnly,
|
||||
Traits::IsFusedQuant,
|
||||
typename Traits::ALayout,
|
||||
typename Traits::BLayout,
|
||||
typename Traits::CLayout,
|
||||
decltype(get_activation_())>;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenMoeGemmShape,
|
||||
CodegenMoeGemmTraits>;
|
||||
CodegenMoeGemmTraits,
|
||||
AccDataType>;
|
||||
|
||||
using CodegenMoeGemmPolicy = ck_tile::UniversalFlatmmPipelineAgBgCrPolicy;
|
||||
using CodegenMoeGemmPipeline =
|
||||
@@ -77,7 +91,7 @@ float moe_gemm(const moe_gemm_kargs& gemm_desc, const ck_tile::stream_config& s)
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
CLayout,
|
||||
typename Traits::CLayout,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
|
||||
@@ -70,10 +70,10 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename Traits>
|
||||
float invoke_moe_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args)
|
||||
{
|
||||
float ave_time = moe_gemm<ALayout, BLayout, CLayout>(
|
||||
float ave_time = moe_gemm<Traits>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::string op_name{"Moe Gemm"};
|
||||
@@ -94,12 +94,9 @@ float invoke_moe_gemm(int n_warmup, int n_repeat, const moe_gemm_kargs& args)
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename Traits>
|
||||
int run_moe_gemm_example_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
|
||||
@@ -129,6 +126,11 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
const ck_tile::index_t topk = arg_parser.get_int("TopK");
|
||||
const ck_tile::index_t repeat = arg_parser.get_int("repeat");
|
||||
const ck_tile::index_t experts = arg_parser.get_int("experts");
|
||||
const std::string mfma = arg_parser.get_str("prec");
|
||||
|
||||
auto a_layout = typename Traits::ALayout{};
|
||||
auto b_layout = typename Traits::BLayout{};
|
||||
auto c_layout = typename Traits::CLayout{};
|
||||
|
||||
// TODO: replace the magic declaration
|
||||
const ck_tile::index_t MPerBlock = 128;
|
||||
@@ -153,7 +155,7 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
|
||||
stride_A = ck_tile::get_default_stride(num_tokens, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(num_tokens * topk, N, stride_C, is_row_major(CLayout{}));
|
||||
stride_C = ck_tile::get_default_stride(num_tokens * topk, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
auto a_m_k_tensor = ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(num_tokens, K, stride_A, is_row_major(a_layout)));
|
||||
@@ -164,10 +166,9 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
? ck_tile::host_tensor_descriptor(experts * N, K, stride_B, is_row_major(b_layout))
|
||||
: ck_tile::host_tensor_descriptor(K, experts * N, stride_B, is_row_major(b_layout)));
|
||||
|
||||
std::string mfma = arg_parser.get_str("prec");
|
||||
|
||||
auto c_m_n_tensor = ck_tile::HostTensor<CDataType>(
|
||||
ck_tile::host_tensor_descriptor(num_tokens * topk, N, stride_C, is_row_major(CLayout{})));
|
||||
ck_tile::host_tensor_descriptor(num_tokens * topk, N, stride_C, is_row_major(c_layout)));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensor);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensor);
|
||||
@@ -268,7 +269,7 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
stride_B,
|
||||
stride_C};
|
||||
|
||||
invoke_moe_gemm<ALayout, BLayout, CLayout>(3, repeat, gemm_desc);
|
||||
invoke_moe_gemm<Traits>(3, repeat, gemm_desc);
|
||||
|
||||
c_m_n_dev_buf->FromDevice(c_m_n_tensor.data());
|
||||
|
||||
@@ -276,7 +277,7 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
if(arg_parser.get_int("validate"))
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
num_tokens * topk, N, stride_C, is_row_major(CLayout{})));
|
||||
num_tokens * topk, N, stride_C, is_row_major(c_layout)));
|
||||
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
@@ -289,9 +290,9 @@ int run_moe_gemm_example_with_layouts(int argc,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(
|
||||
typename Traits::ALayout,
|
||||
typename Traits::BLayout,
|
||||
typename Traits::CLayout>(
|
||||
p_sorted_token_ids_dev,
|
||||
p_expert_ids_dev,
|
||||
p_max_token_id_dev,
|
||||
@@ -353,15 +354,20 @@ int run_moe_gemm_example(int argc, char* argv[])
|
||||
return -1;
|
||||
}
|
||||
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
const std::string a_layout = arg_parser.get_str("a_layout");
|
||||
const std::string b_layout = arg_parser.get_str("b_layout");
|
||||
// const ck_tile::index_t act = arg_parser.get_int("act");
|
||||
// const ck_tile::index_t gate_only = arg_parser.get_int("gate_only");
|
||||
// const ck_tile::index_t fused_quant = arg_parser.get_int("fquant");
|
||||
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_moe_gemm_example_with_layouts(argc, argv, Row{}, Col{}, Row{});
|
||||
using Traits = MoeGemmHostTraits<Row, Col, Row, 0, false, false>;
|
||||
|
||||
return run_moe_gemm_example_with_layouts<Traits>(argc, argv);
|
||||
}
|
||||
// else if(a_layout == "R" && b_layout == "R")
|
||||
// {
|
||||
|
||||
@@ -11,66 +11,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// template <typename ADataType,
|
||||
// typename BDataType,
|
||||
// typename AccDataType,
|
||||
// typename CDataType,
|
||||
// typename AElementOp = ck_tile::identity,
|
||||
// typename BElementOp = ck_tile::identity,
|
||||
// typename ACCElementOp = ck_tile::identity>
|
||||
// CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
|
||||
// const HostTensor<BDataType>& b_k_n,
|
||||
// HostTensor<CDataType>& c_m_n,
|
||||
// const AElementOp& a_element_op = {},
|
||||
// const BElementOp& b_element_op = {},
|
||||
// const ACCElementOp& acc_element_op = {})
|
||||
// {
|
||||
// const std::size_t M = a_m_k.get_length(0);
|
||||
// const std::size_t N = b_k_n.get_length(1);
|
||||
// const std::size_t K = a_m_k.get_length(1);
|
||||
|
||||
// auto f_mn = [&](auto m, auto n) {
|
||||
// AccDataType v_acc = 0;
|
||||
|
||||
// for(std::size_t k = 0; k < K; ++k)
|
||||
// {
|
||||
// AccDataType v_a;
|
||||
// AccDataType v_b;
|
||||
// if constexpr(std::is_same_v<ADataType, pk_int4_t>)
|
||||
// {
|
||||
// const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
|
||||
// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
// if(k % 2 == 1)
|
||||
// v_a = fp32_val.hi;
|
||||
// else
|
||||
// v_a = fp32_val.lo;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
|
||||
// }
|
||||
// if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
// {
|
||||
// const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
|
||||
// const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
|
||||
// if(k % 2 == 1)
|
||||
// v_b = fp32_val.hi;
|
||||
// else
|
||||
// v_b = fp32_val.lo;
|
||||
// }
|
||||
// else
|
||||
// {
|
||||
// v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
|
||||
// }
|
||||
// v_acc += v_a * v_b;
|
||||
// }
|
||||
|
||||
// c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
|
||||
// };
|
||||
|
||||
// make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
|
||||
// }
|
||||
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
@@ -78,7 +18,9 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
bool IsInputGemm = true>
|
||||
bool IsInputGemm = true,
|
||||
bool IsGateOnly = true,
|
||||
index_t GateActivation = 0>
|
||||
__global__ void naive_gemm_kernel(const ck_tile::index_t* p_sorted_token_ids_,
|
||||
const ck_tile::index_t* p_sorted_expert_ids_,
|
||||
const ck_tile::index_t* p_max_token_id_,
|
||||
@@ -192,7 +134,9 @@ template <typename ADataType,
|
||||
typename LayoutA,
|
||||
typename LayoutB,
|
||||
typename LayoutC,
|
||||
bool IsInputGemm = true>
|
||||
bool IsInputGemm = true,
|
||||
bool IsGateOnly = true,
|
||||
index_t GateActivation = 0>
|
||||
void reference_moe_gemm_gpu(const index_t* p_sorted_token_ids_,
|
||||
const index_t* p_sorted_expert_ids_,
|
||||
const index_t* p_max_token_id_,
|
||||
|
||||
@@ -5,11 +5,11 @@
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
|
||||
|
||||
#include "ck_tile/ops/moe_gemm/kernel/moe_gemm_kernel.hpp"
|
||||
#include "ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm.hpp"
|
||||
#include "ck_tile/ops/moe_gemm/pipeline/moe_gemm_pipeline_agmem_bgmem_creg_flatmm_policy.hpp"
|
||||
#include "ck_tile/ops/moe_gemm/pipeline/tile_moe_gemm_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/common/utils.hpp"
|
||||
|
||||
@@ -53,8 +53,7 @@ struct MoeGemmHostArgs : public ck_tile::FlatmmHostArgs
|
||||
|
||||
template <typename TilePartitioner_,
|
||||
typename FlatmmPipeline_,
|
||||
typename EpiloguePipeline_,
|
||||
bool IsInputGemm_ = true>
|
||||
typename EpiloguePipeline_>
|
||||
struct MoeGemmKernel
|
||||
{
|
||||
using TilePartitioner = remove_cvref_t<TilePartitioner_>;
|
||||
@@ -66,7 +65,7 @@ struct MoeGemmKernel
|
||||
using BlockGemmShape =
|
||||
remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
static constexpr bool IsInputGemm = IsInputGemm_;
|
||||
static constexpr bool IsInputGemm = FlatmmPipeline::IsInputGemm;
|
||||
|
||||
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
|
||||
@@ -635,7 +634,8 @@ struct MoeGemmKernel
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(number<2>{});
|
||||
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
|
||||
EpiloguePipeline{}.template operator()<decltype(c_block_window),
|
||||
decltype(c_block_tile)>(
|
||||
c_block_window,
|
||||
c_block_tile,
|
||||
smem_ptr_0,
|
||||
|
||||
@@ -29,11 +29,17 @@ struct MoeGemmPipelineAgBgCrImpl
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using GateActivation = remove_cvref_t<typename Problem::Traits::GateActivation>;
|
||||
|
||||
using BlockFlatmm = remove_cvref_t<decltype(PipelinePolicy::template GetBlockFlatmm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr bool IsInputGemm = Problem::Traits::IsInputGemm;
|
||||
static constexpr bool IsGateOnly = Problem::Traits::IsGateOnly;
|
||||
static constexpr bool IsFusedQuant = Problem::Traits::IsFusedQuant;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
@@ -113,32 +119,6 @@ struct MoeGemmPipelineAgBgCrImpl
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// auto a_dist = PipelinePolicy::template MakeADramTileDistribution<Problem>();
|
||||
// auto a_coord = a_dist.calculate_index();
|
||||
// using ADstrEncode = typename decltype(a_dist)::DstrEncode;
|
||||
// constexpr ck_tile::index_t MRepeat = ADstrEncode::hs_lengthss_[I0][I0];
|
||||
// statically_indexed_array<ck_tile::index_t, NRepeat> a_offsets;
|
||||
// static_for<0, MRepeat, 1>{}([&](auto n0) {
|
||||
// int32_t seqlen_k_idx_per_repeat = cur_seqlen_k_idx + k_coord[0] + Traits::kBlockN / NRepeat * n0.value;
|
||||
// int32_t page_idx = seqlen_k_idx_per_repeat / page_block_size;
|
||||
// int32_t seq_idx = seqlen_k_idx_per_repeat % page_block_size;
|
||||
// k_offsets[n0] = (block_indices[page_idx] * page_block_size + seq_idx) * stride_s_k;
|
||||
// });
|
||||
//
|
||||
// // A DRAM tile window for load
|
||||
// auto a_dram_tile = ck_tile::make_tile_scatter_gather(
|
||||
// a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
// a_dram_block_window_tmp.get_window_lengths(),
|
||||
// a_dram_block_window_tmp.get_window_origin(),
|
||||
// a_dist,
|
||||
// k_offsets); // K DRAM tile window for
|
||||
|
||||
// auto a_copy_dram_window =
|
||||
// make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
// make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
// a_dram_block_window_tmp.get_window_origin(),
|
||||
// PipelinePolicy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
@@ -223,6 +203,15 @@ struct MoeGemmPipelineAgBgCrImpl
|
||||
block_flatmm(c_block_tile, a_lds_gemm_window, b_flat_dram_window);
|
||||
}
|
||||
|
||||
sweep_tile(c_block_tile,
|
||||
[&](auto idx0, auto idx1) {
|
||||
fp32x2_t v_{c_block_tile(idx0), c_block_tile(idx1)};
|
||||
GateActivation{}(v_, v_);
|
||||
c_block_tile(idx0) = v_.x;
|
||||
c_block_tile(idx1) = v_.y;
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
|
||||
@@ -240,6 +229,163 @@ struct MoeGemmPipelineAgBgCrImpl
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp, typename AElementFunction>
|
||||
CK_TILE_HOST_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window,
|
||||
const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t N,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindow::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindow{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
static_assert(kKPerBlock == ADramBlockWindow{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// A LDS tile window for store
|
||||
auto a_copy_lds_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// A LDS tile for block GEMM
|
||||
auto a_lds_gemm_window = make_tile_window(
|
||||
a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_flatmm = BlockFlatmm();
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
|
||||
auto b_gate_flat_dram_window =
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
b_flat_dram_block_window_tmp.move({N, 0})
|
||||
auto b_up_flat_dram_window =
|
||||
make_tile_window(
|
||||
b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp>{}, number<flatKPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
using c_block_tile_type = decltype(block_flatmm(a_lds_gemm_window, b_gate_flat_dram_window));
|
||||
auto c_block_tiles[2] = {c_block_tile_type{}, c_block_tile_type{}};
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
auto a_block_tile = a_dram_block_window.load();
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(a_dram_block_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[0]);
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tiles[1]);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
PipelinePolicy::template MakeShuffledARegBlockDistribution<Problem>());
|
||||
shuffle_tile(a_shuffle_tmp, a_block_tile);
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
|
||||
}
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
a_dram_block_window.load(a_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM i
|
||||
block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window);
|
||||
|
||||
//TODO: simply add b_gate flatmm
|
||||
block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(a_dram_block_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
|
||||
store_tile(a_copy_lds_window, a_block_tile_tmp);
|
||||
|
||||
// move to next flat K
|
||||
move_tile_window(b_gate_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
move_tile_window(b_up_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// GEMM num_loop - 1
|
||||
block_flatmm(c_block_tiles[0], a_lds_gemm_window, b_gate_flat_dram_window);
|
||||
block_flatmm(c_block_tiles[1], a_lds_gemm_window, b_up_flat_dram_window);
|
||||
}
|
||||
|
||||
sweep_tile(c_block_tiles[0],
|
||||
[&](auto idx0, auto idx1) {
|
||||
fp32x2_t v_{c_block_tiles[0].at(number<0>{})(idx0), c_block_tiles[0].at(number<0>{})(idx1)};
|
||||
typename Problem::GateActivation{}(v_, v_);
|
||||
c_block_tiles[0].at(number<0>{})(idx0) = v_.x;
|
||||
c_block_tiles[0].at(number<0>{})(idx1) = v_.y;
|
||||
},
|
||||
sequence<1, 2>{});
|
||||
|
||||
auto c_block_tile =
|
||||
tile_elementwise_in([&](const auto& a_, const auto& b_) { return a_ * b_; },
|
||||
c_block_tiles[0],
|
||||
c_block_tiles[1]);
|
||||
|
||||
return c_block_tiles[0];
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindow, typename BFlatBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(ADramBlockWindow& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t N,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
N,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool IsInputGemm_,
|
||||
bool IsGateOnly_,
|
||||
bool IsFusedQuant_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
typename GateActivation_>
|
||||
struct TileMoeGemmTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
static constexpr bool IsInputGemm = IsInputGemm_;
|
||||
static constexpr bool IsGateOnly = IsGateOnly_;
|
||||
static constexpr bool IsFusedQuant = IsFusedQuant_;
|
||||
|
||||
// TODO this can't be hardcoded here! Should be in policy!
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
|
||||
using GateActivation = remove_cvref_t<GateActivation_>;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
};
|
||||
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user