No atomic passes!

This commit is contained in:
Ali Nouri
2025-09-26 22:35:44 +00:00
parent 2cba567e5b
commit fa8baf3ce6
2 changed files with 13 additions and 4 deletions

View File

@@ -5,6 +5,10 @@
#include "fused_moegemm_api_traits.hpp"
#include "ck_tile/ops/fused_moe.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner_no_atomic.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner_debug.hpp"
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_tile_partitioner_simple_no_atomic.hpp"
#include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_flatmm_policy_no_atomic.hpp"
#include <iostream>
template <ck_tile::index_t... Is>
@@ -14,7 +18,8 @@ using S = ck_tile::sequence<Is...>;
template <typename Ts_>
float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
{
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 1 /*atomic*/>;
// STEP 1: Test no-atomic with simple partitioner
using f_traits = ck_tile::FusedMoeGemmTraits<Ts_::GateOnly, Ts_::FusedQuant == 1, 0 /*NO atomic*/>;
using f_shape = ck_tile::FusedMoeGemmShape<typename Ts_::BlockTile_0,
typename Ts_::WarpPerBlock_0,
typename Ts_::WarpTile_0,
@@ -48,11 +53,14 @@ float fused_moegemm_(const ck_tile::stream_config& s, fused_moegemm_args a)
f_traits>;
// using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmEx<f_problem>;
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem>;
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_Linear<f_shape>;
// FINAL: Complete no-atomic implementation
using f_pipeline = ck_tile::FusedMoeGemmPipeline_FlatmmUk<f_problem, ck_tile::FusedMoeGemmPipelineFlatmmPolicy_NoAtomic>;
// FINAL: Production-ready no-atomic partitioner with proper validation
using f_partitioner = ck_tile::FusedMoeGemmTilePartitioner_NoAtomic<f_shape>;
using f_kernel = ck_tile::FusedMoeGemmKernel<f_partitioner, f_pipeline, void>;
const dim3 grids = dim3(1, 1, 1); //f_kernel::GridSize(a);
// CHANGE 3: Use proper grid calculation (no longer hardcoded)
const dim3 grids = f_kernel::GridSize(a);
const dim3 blocks = f_kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;

View File

@@ -382,6 +382,7 @@ struct FusedMoeGemmKernel
return d_window_;
}();
static_assert(false && "This code path should be hit");
auto o_window = [&]() {
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr);
auto o_view_ = make_naive_tensor_view<address_space_enum::global,