mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
rewite save o code
This commit is contained in:
@@ -306,25 +306,23 @@ struct FusedMoeGemmPipeline_General
|
||||
make_tuple(number<32>{}, number<32>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeGlobalTileDistribution_O<Problem>());
|
||||
ignore = o_alds_win;
|
||||
|
||||
auto save_o = [&]() {
|
||||
if(blockIdx.x == 0 && (blockIdx.y == 0 || blockIdx.y == 1) && blockIdx.z == 0)
|
||||
{
|
||||
if(threadIdx.x < 64)
|
||||
{
|
||||
auto o0 = load_tile(o_olds_win);
|
||||
for(int step = 1; step < 4; step++)
|
||||
{
|
||||
auto o0 = load_tile(o_olds_win);
|
||||
constexpr index_t thread_buffer_size = decltype(o0)::get_thread_buffer_size();
|
||||
static_for<1, BlockShape::Repeat_K1, 1>{}([&](auto) {
|
||||
move_tile_window(o_olds_win, {32, 0});
|
||||
auto o1 = load_tile(o_olds_win);
|
||||
for(int i = 0; i < 16; i++)
|
||||
{
|
||||
static_for<0, thread_buffer_size, 1>{}([&](auto i) {
|
||||
o0.get_thread_buffer()(i) = type_convert<ODataType>(
|
||||
type_convert<float>(o0.get_thread_buffer()[i]) +
|
||||
type_convert<float>(o1.get_thread_buffer()[i]));
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
update_tile(o_window_, o0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,14 +216,15 @@ struct FusedMoeGemmPipelineGeneralPolicy
|
||||
typename S_::WarpTile_0>>;
|
||||
|
||||
constexpr auto warp_gemm = GetWarpGemm0<Problem>();
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
// using BlockGemmPolicy =
|
||||
// BlockGemmASmemBRegCRegV1CustomPolicy<typename
|
||||
// Problem::ADataType,
|
||||
typename Problem::GDataType,
|
||||
typename Problem::AccDataType,
|
||||
typename S_::WarpPerBlock_0,
|
||||
decltype(warp_gemm)>;
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
// using BlockGemmPolicy =
|
||||
// BlockGemmASmemBRegCRegV1CustomPolicy<typename
|
||||
// Problem::ADataType,
|
||||
typename Problem::GDataType,
|
||||
typename Problem::AccDataType,
|
||||
typename S_::WarpPerBlock_0,
|
||||
decltype(warp_gemm)>;
|
||||
|
||||
return BlockGemmASmemBSmemCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
// return BlockGemmASmemBRegCRegV1<GemmProblem, BlockGemmPolicy>{};
|
||||
|
||||
Reference in New Issue
Block a user