mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
* moe pipeline * update code * compile OK * update * update cpu reference * update pipeline_gemm0 * compiler ok * update pipeline * rename to ex pipeline * block-asm * update * update * update first gemm ok * compute correct * update file structure * update README * update * update * update code * update API * return unsupport case * add comment * update readme * update * uncomment * update * fix build err --------- Co-authored-by: valarLip <340077269@qq.com>
53 lines
2.2 KiB
C++
53 lines
2.2 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include <string>

#include "fused_moesorting.hpp"
#include "fused_moegemm.hpp"
|
|
|
|
struct fused_moe_args
|
|
{
|
|
const void* a_ptr; // [m, k], input token
|
|
const void* a_scale_ptr; // [m, 1], token scale
|
|
const void* g_ptr; // [e, n, k]/[e, 2*n, k], pre-shuffle([e, nr, kr, w])
|
|
const void* d_ptr; // [e, n, k], pre-shuffle([e, nr, kr, w])
|
|
const void* g_scale_ptr; // [e, 1, n], gate(up) scale
|
|
const void* d_scale_ptr; // [e, 1, k], down scale
|
|
const void* y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
|
|
void* o_ptr; // [m, k], output token (no need to do zeroing)
|
|
|
|
const void* topk_ids_ptr; // [tokens, topk]
|
|
const void* topk_weight_ptr; // [tokens, topk]
|
|
void* sorted_token_ids_ptr; // [max_num_tokens_padded]
|
|
void* sorted_weight_ptr; // [max_num_tokens_padded]
|
|
void* sorted_expert_ids_ptr; // [(max_num_tokens_padded + block_size - 1) / block_size]
|
|
void* num_sorted_tiles_ptr; // [1]
|
|
|
|
ck_tile::index_t block_m; // block_m, used to devide the input
|
|
ck_tile::index_t hidden_size; // k
|
|
ck_tile::index_t intermediate_size; // n / TP, for Gate. if Gate+Up, Down need divide by 2
|
|
ck_tile::index_t num_tokens; // input number of tokens for current iteration
|
|
ck_tile::index_t num_experts; // number of groups
|
|
ck_tile::index_t topk; // need this?
|
|
|
|
ck_tile::index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
|
};
|
|
|
|
// This is the public API; it will be generated by a script
|
|
// Precision and feature selection passed to fused_moe() alongside fused_moe_args.
// The precision fields are data-type names given as strings. The exact accepted
// spellings are defined by the implementation, which is not visible here.
// The integer members default to 0 so a default-initialized object is never
// indeterminate. The struct remains an aggregate, so brace-initialization is unchanged.
struct fused_moe_traits
{
    std::string prec_i;  // input precision
    std::string prec_w;  // weight precision
    std::string prec_o;  // output precision
    std::string prec_st; // token scale data type
    std::string prec_sw; // weight scale data type
    std::string prec_sq; // smooth quant scale data type
    std::string prec_kw; // topk-weight data type
    int block_m     = 0; // tile size along m; presumably must match fused_moe_args::block_m — TODO confirm
    int gate_only   = 0; // nonzero: gate only; presumably 0 means fused gate+up ([e, 2*n, k] weights)
    int fused_quant = 0; // 0:no-sweep (presumably no fused quant), 1:smooth-dynamic-quant, 2:dynamic-quant
};
|
|
|
|
/// Public entry point of the fused MoE operator.
///
/// @param traits  precision / tiling / feature selection (fused_moe_traits)
/// @param args    device buffers and problem sizes (fused_moe_args)
/// @param stream  ck_tile stream / launch configuration
/// @return a float whose meaning is defined by the implementation; presumably the
///         kernel time reported through @p stream (ck_tile convention), possibly a
///         sentinel for unsupported configurations — NOTE(review): confirm.
float fused_moe(fused_moe_traits traits, fused_moe_args args, const ck_tile::stream_config& stream);
|