From 08e902e2fa730cb48d335953329250aeae96be50 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 5 Mar 2025 09:34:56 +0000 Subject: [PATCH] hot fix moe gemm2 --- include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index d0e06a6c53..5337fd5e2c 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1492,7 +1492,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 1; + constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -2000,7 +2000,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = 1; + constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),