From 7cca528aa7c2ef89875b255e698b1c04bd62f190 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Tue, 4 Mar 2025 14:24:22 +0000 Subject: [PATCH] fix per token quant --- .../65_gemm_multiply_multiply/moe_gemm1.cpp | 6 ++-- .../65_gemm_multiply_multiply/moe_gemm2.cpp | 16 +++++----- .../gpu/grid/gridwise_moe_gemm.hpp | 29 +++---------------- 3 files changed, 15 insertions(+), 36 deletions(-) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_gemm1.cpp index 1f22c28791..8263d68a71 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1.cpp @@ -224,7 +224,7 @@ int main(int argc, char* argv[]) ck::index_t StrideB = K; ck::index_t StrideE = N; constexpr ck::index_t NumDTensor = DsDataType::Size(); - constexpr auto StrideDs = std::array{0, 0}; + constexpr auto StrideDs = std::array{1, 0}; ck::index_t KBatch = 1; @@ -257,8 +257,8 @@ int main(int argc, char* argv[]) sorted_token_ids.mData[i] = tokens; } } - expert_ids.savetxt("expert_ids.txt", "int"); - sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + // expert_ids.savetxt("expert_ids.txt", "int"); + // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); Tensor a0_t_k(HostTensorDescriptor({tokens, K}, {K, 1})); Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); diff --git a/example/65_gemm_multiply_multiply/moe_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_gemm2.cpp index d1870d0068..bdc62a5b87 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2.cpp @@ -311,12 +311,12 @@ int main(int argc, char* argv[]) DeviceMem d1_device_buf(sizeof(D1DataType) * d1_e_n.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - a0_t_k_k.savetxt("a.txt"); - expert_ids.savetxt("expert_ids.txt", "int"); - sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); - d0_t_n.savetxt("d0_t_n.txt", "int"); - d1_e_n.savetxt("d1_e_n.txt", "int"); - d2_e_n.savetxt("d2_e_n.txt", "int"); + // a0_t_k_k.savetxt("a.txt"); + // expert_ids.savetxt("expert_ids.txt", "int"); + // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + // d0_t_n.savetxt("d0_t_n.txt", "int"); + // d1_e_n.savetxt("d1_e_n.txt", "int"); + // d2_e_n.savetxt("d2_e_n.txt", "int"); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -434,8 +434,8 @@ int main(int argc, char* argv[]) } e_device_buf.FromDevice(e_t_n_device_result.mData.data()); - e_t_n_device_result.savetxt("out.txt"); - e_t_n_host_result.savetxt("ref.txt"); + // e_t_n_device_result.savetxt("out.txt"); + // e_t_n_host_result.savetxt("ref.txt"); return ck::utils::check_err( e_t_n_device_result, e_t_n_host_result, "Error: Incorrect results!", 1e-3, 5e-2) ? 0 diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index b2337a7f9a..96b2e8e075 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -1492,7 +1492,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = 1; auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -1576,7 +1576,7 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; + float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]]; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); @@ -2000,7 +2000,7 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; // hack fix felix + constexpr index_t scatter_weight_idx = 1; auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), @@ -2084,7 +2084,7 @@ struct GridwiseMoeGemm static_for<0, EMRepeats, 1>{}([&](auto m0) { const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; index_t token_offset = fused_token & 0xffffff; - float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; + float weight = p_sorted_weights_0[token_offset * problem.StrideDs[0]]; if constexpr(IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); @@ -2139,27 +2139,6 @@ struct GridwiseMoeGemm }); } } - - // template - // __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, - // const index_t* p_sorted_expert_ids, - // const index_t* p_max_token_id, - // const ADataType* p_a_grid, - // const BDataType* p_b_grid, - // DsGridPointer& p_ds_grid, - // CDataType* p_c_grid, - // void* p_shared, - // void* p_shared1, - // const Problem& problem, - // AElementwiseOperation a_element_op, - // BElementwiseOperation b_element_op, - // CElementwiseOperation c_element_op) - // { - - // } }; } // namespace ck