mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add moe general
This commit is contained in:
@@ -19,13 +19,13 @@ float fused_moegemm(fused_moegemm_traits t, fused_moegemm_args a, const ck_tile:
|
||||
if(t.prec_i == "bf16" && t.prec_w == "bf16" && t.prec_o == "bf16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
using t_ = fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
else if(t.prec_i == "fp16" && t.prec_w == "fp16" && t.prec_o == "fp16" && t.prec_st == "fp32" &&
|
||||
t.prec_sw == "fp32" && t.prec_sq == "fp32" && t.prec_kw == "fp32" && t.block_m == 32 && t.gate_only == 1)
|
||||
{
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>;
|
||||
using t_ = fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>;
|
||||
r = fused_moegemm_<t_>(s, a);
|
||||
}
|
||||
// clang-format on
|
||||
|
||||
@@ -19,8 +19,8 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0,
|
||||
typename Ts_::BlockTile_1,
|
||||
typename Ts_::WarpPerBlock_0,
|
||||
typename Ts_::WarpTile_0>;
|
||||
typename Ts_::WarpPerBlock_1,
|
||||
typename Ts_::WarpTile_1>;
|
||||
using f_problem =
|
||||
ck_tile::FusedMoeGemmPipelineProblem<typename Ts_::ADataType,
|
||||
typename Ts_::GDataType,
|
||||
@@ -38,9 +38,9 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
|
||||
f_traits>;
|
||||
|
||||
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
|
||||
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
|
||||
using f_pipeline = ck_tile::FusedMoeGemmPipeline_General<f_problem>;
|
||||
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
|
||||
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
|
||||
using f_kernel = ck_tile::FusedMoeGemmGlKernel<f_partitioner, f_pipeline, void>;
|
||||
|
||||
const dim3 grids = f_kernel::GridSize(a);
|
||||
constexpr dim3 blocks = f_kernel::BlockSize();
|
||||
|
||||
@@ -44,8 +44,8 @@ struct fmoe_ // traits, ugly name, only used for internal
|
||||
using WarpPerBlock_0 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_0 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_ / (GateOnly_ ? 1 : 2)>;
|
||||
using WarpPerBlock_1 = ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using BlockTile_1 = ck_tile::sequence<BT_, BD_, BI_>;
|
||||
using WarpPerBlock_1 = ck_tile::sequence<1, 1, 4>;//ck_tile::remove_cvref_t<WarpPerBlock_>;
|
||||
using WarpTile_1 = ck_tile::remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr ck_tile::index_t GateOnly = GateOnly_;
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
fmoe_<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -8,7 +8,7 @@
|
||||
|
||||
// clang-format off
|
||||
template float fused_moegemm_<
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 512, 128, 128>, S<1, 4, 1>, S<16, 16, 32>, 1, 0>
|
||||
fmoe_<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float, float, float, float, S<32, 128, 32, 32>, S<1, 4, 1>, S<32, 32, 8>, 1, 0>
|
||||
>(const ck_tile::stream_config& s, fused_moegemm_args a);
|
||||
|
||||
// clang-format on
|
||||
|
||||
@@ -261,8 +261,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
}
|
||||
|
||||
// permute weight
|
||||
ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
// ck_tile::HostTensor<GDataType> g_perm_host = shuffle_moe_weight(g_host, prec_w, 1);
|
||||
// ck_tile::HostTensor<DDataType> d_perm_host = shuffle_moe_weight(d_host, prec_w, 1);
|
||||
|
||||
// do moe sorting
|
||||
if(balance)
|
||||
@@ -345,8 +345,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
// done, preparing GPU buffer
|
||||
ck_tile::DeviceMem a_buf(a_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_perm_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_perm_host);
|
||||
ck_tile::DeviceMem g_perm_buf(g_host);
|
||||
ck_tile::DeviceMem d_perm_buf(d_host);
|
||||
ck_tile::DeviceMem sa_buf(sa_host);
|
||||
ck_tile::DeviceMem sg_buf(sg_host);
|
||||
ck_tile::DeviceMem sd_buf(sd_host);
|
||||
@@ -390,7 +390,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride};
|
||||
stride,
|
||||
max_num_tokens_padded};
|
||||
|
||||
float ave_time = fused_moegemm(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 75 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 90 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 124 KiB |
Binary file not shown.
|
Before Width: | Height: | Size: 18 KiB |
@@ -57,4 +57,76 @@ struct indexing_adaptor_onshot_cached
|
||||
return ck_tile::is_known_at_compile_time<IndexingType>::value;
|
||||
}
|
||||
};
|
||||
#define Using_Gather 1
|
||||
template <typename IndexingType>
|
||||
struct indexing_adaptor
|
||||
{
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {}
|
||||
const IndexingType* cached_idx_;
|
||||
#if Using_Gather
|
||||
mutable index_t pre_up_index_ = 0;
|
||||
mutable index_t pre_low_index_ = 0;
|
||||
#endif
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
const UpIdx& idx_up) const
|
||||
{
|
||||
static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]);
|
||||
#if Using_Gather
|
||||
pre_up_index_ = idx_up[number<0>{}];
|
||||
pre_low_index_ = idx_low(number<0>{});
|
||||
#if 0
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
|
||||
const UpIdxDiff& idx_diff_up,
|
||||
LowIdx& /*idx_low*/,
|
||||
const UpIdx& /*idx_up*/) const
|
||||
{
|
||||
// TODO: nonthing changed here
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
#if !Using_Gather
|
||||
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
|
||||
#else
|
||||
int up_index = idx_diff_up[number<0>{}] + pre_up_index_;
|
||||
int low_index = *(cached_idx_ + up_index);
|
||||
idx_diff_low(number<0>{}) = low_index - pre_low_index_;
|
||||
|
||||
pre_up_index_ = up_index;
|
||||
pre_low_index_ = low_index;
|
||||
#if 0
|
||||
if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)
|
||||
{
|
||||
printf("\n index form %d to %d, diff from %d to %d \n",
|
||||
up_index,
|
||||
low_index,
|
||||
idx_diff_up[number<0>{}],
|
||||
idx_diff_low(number<0>{}));
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// pass the diff to lower, but not changing the actually index
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
|
||||
{
|
||||
return ck_tile::is_known_at_compile_time<IndexingType>::value;
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -213,9 +213,9 @@ struct FusedMoeGemmGlKernel
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
|
||||
{
|
||||
constexpr index_t block_m = BlockShape::Block_M0;
|
||||
int max_num_tokens_padded =
|
||||
hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
|
||||
//constexpr index_t block_m = BlockShape::Block_M0;
|
||||
int max_num_tokens_padded = hargs.max_num_tokens_padded;
|
||||
//hargs.topk * hargs.num_tokens + hargs.num_experts * block_m - hargs.topk;
|
||||
// printf("xxx max_num_tokens_padded:%d\n", max_num_tokens_padded);
|
||||
return Partitioner::GridSize(max_num_tokens_padded, hargs.intermediate_size);
|
||||
}
|
||||
|
||||
@@ -117,6 +117,7 @@ struct FusedMoeGemmHostArgs
|
||||
index_t topk; // need this?
|
||||
|
||||
index_t stride_token; // for input/output, stride for each row, should >= hidden_size
|
||||
index_t max_num_tokens_padded; // size of sorted_token_ids_ptr
|
||||
};
|
||||
|
||||
// This is scatter/gather b2b group-gemm
|
||||
|
||||
Reference in New Issue
Block a user