diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp index 9ab946761c..138d946c26 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp @@ -292,9 +292,14 @@ struct FusedMoeGemmGlKernel number{}, number<1>{}); - const auto g_window_ = make_tile_window( + const auto g_view_1_ = pad_tensor_view( g_view_, make_tuple(number{}, number{}), + sequence{}); + + const auto g_window_ = make_tile_window( + g_view_1_, + make_tuple(number{}, number{}), {idx_n0, 0}); return g_window_; @@ -328,9 +333,14 @@ struct FusedMoeGemmGlKernel number{}, number<1>{}); - const auto d_window_ = make_tile_window( + const auto d_view_1_ = pad_tensor_view( d_view_, make_tuple(number{}, number{}), + sequence{}); + + const auto d_window_ = make_tile_window( + d_view_1_, + make_tuple(number{}, number{}), {0, idx_n0}); return d_window_; }(); diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp index 3232f80fcc..8f15cac8b3 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp @@ -391,7 +391,7 @@ struct FusedMoeGemmKernel number{}, number<1>{}); - // gather is here + // scatter is here auto o_scatter_view_ = transform_tensor_view( o_view_, make_tuple(make_indexing_transform(kargs.num_tokens, token_id), diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp index 45e1340912..fc54d0ab56 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp @@ -71,9 +71,7 @@ struct FusedMoeGemmPipeline_General CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeA() { // matrix a or tokens smem - constexpr index_t smem_mat_a = - BlockShape::Block_M0 * BlockShape::Block_K0 * sizeof(ADataType); - return smem_mat_a; + return Policy::template GetSmemSize_A(); } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { @@ -131,11 +129,8 @@ struct FusedMoeGemmPipeline_General CK_TILE_LDS_ADDR void* smem, index_t hidden_size, index_t /*intermediate_size*/, - CWindow& c_window_) + CWindow& /*c_window_*/) { - ignore = c_window_; - ignore = hidden_size; - ignore = w_window_; CK_TILE_LDS_ADDR ADataType* smem_0 = reinterpret_cast(smem); CK_TILE_LDS_ADDR GDataType* smem_1 = reinterpret_cast( smem_0 + GetSmemSizeA() / sizeof(ADataType)); @@ -234,11 +229,11 @@ struct FusedMoeGemmPipeline_General #if 0 PrintMem(y_pre, "Y_pre", 0); #endif - if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) - { - block_sync_lds(); - store_tile(c_window_, y_pre); - } + // if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) + // { + // block_sync_lds(); + // store_tile(c_window_, y_pre); + // } // save to lds auto bridge_lds_view = make_tensor_view( smem_0, Policy::template MakeBridgeLdsBlockDesc()); diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp index 5040e192a9..83dda0c114 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp @@ -312,12 +312,6 @@ struct FusedMoeGemmPipelineGeneralPolicy make_tuple(sequence<1>{}, sequence<0, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); - // constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( - // make_tuple(number{}, number{}), - // make_tuple(number{}, number<1>{}), - // number<8>{}, - // number<1>{}); - return a_lds_block_desc; }