diff --git a/example/65_gemm_multiply_multiply/moe_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_gemm1.cpp index 7eddc58846..0f778ed4d4 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1.cpp @@ -71,7 +71,7 @@ struct MulABScale (void)d2; // for gate, no d2 needed (void)d0; (void)d1; - const float x0_f = c; + const float x0_f = c * d1 * d0; // const float x0_f = c; e = ck::type_convert(x0_f); } @@ -286,9 +286,9 @@ int main(int argc, char* argv[]) case 1: a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{0, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d2_m_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_2{1, 3}); + d1_e_n.GenerateTensorValue(GeneratorTensor_2{1, 3}); + d2_m_n.GenerateTensorValue(GeneratorTensor_2{1, 3}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); @@ -304,6 +304,9 @@ int main(int argc, char* argv[]) d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d2_m_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } + d0_t_n.savetxt("d0_t_n.txt", "int"); + d1_e_n.savetxt("d1_e_n.txt", "int"); + d2_m_n.savetxt("d2_m_n.txt", "int"); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem a0_device_buf(sizeof(A0DataType) * a0_t_k.mDesc.GetElementSpaceSize()); @@ -325,8 +328,6 @@ int main(int argc, char* argv[]) auto b_element_op = BElementOp{}; auto cde_element_op = CDEElementOp{}; - constexpr auto I0 = ck::Number<0>{}; - // do GEMM auto device_op = DeviceOpInstance{}; @@ -352,7 +353,7 @@ int main(int argc, char* argv[]) K, StrideA, StrideB, - std::array{I0, I0, I0}, + StrideDs, StrideE, KBatch, a_element_op, @@ -406,9 +407,10 @@ int main(int argc, char* argv[]) { const int t = sorted_token_ids(m); + const int e = expert_ids(m / sorted_tile_size); for(int n = 0; n < N; ++n) { - cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(m, n), d2_m_n(m, n)); + cde_element_op(e_m_n_host_result(m, n), c_m_n(m, n), d0_t_n(t, n), d1_e_n(e, n), d2_m_n(m, n)); } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp index 507b528894..5eda27f3d8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_gather.hpp @@ -1401,7 +1401,7 @@ struct GridwiseMoeGemmGather if (i.value == 1) { ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1); - // if ( threadIdx.x ==0) + // if ( threadIdx.x % 16 ==0) // printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]); } return make_dynamic_buffer( @@ -1448,10 +1448,11 @@ struct GridwiseMoeGemmGather StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; StaticallyIndexedArray scatter_weights; //= for topk // too hack here, 2 specific for topk weights, fixme - const float *p_sorted_weights = p_ds_grid[I2]; + const float *p_sorted_weights = p_ds_grid[I0]; static_for<0, EMRepeats, 1>{}([&](auto m0) { scatter_offsets(m0) = 0; - scatter_weights(m0) = p_sorted_weights[c_token_pos + m0]; + scatter_weights(m0) = p_sorted_weights[(c_token_pos + m0) * problem.StrideDs[0]]; + // if(threadIdx.x % 16 == 0) // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0)); }); auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index 5aaa06303b..b62c18538d 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -176,10 +176,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter src_coords_[i]); oob_val = oob_val & is_src_valid; - if (i.value == ScatterWeightIdx) + if (i.value == ScatterWeightIdx) { - static_assert(SrcScalarPerVectors{}[Number<2>{}] == 1, "scatter weight dim, should only one vec"); + static_assert(SrcScalarPerVectors{}[Number{}] == 1, "scatter weight dim, should only one vec"); constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); + // if(threadIdx.x % 8 ==0 ) + // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number{})); static_for<0, SrcScalarPerVector, 1>{}( [&](auto j) { src_vectors(i).template AsType()(j) = scatter_weights_(Number{}); }); } @@ -189,11 +191,15 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter using DataType = remove_cvref_t; const auto tmp = src_bufs[i].template Get(src_coords_[i].GetOffset(), true); + // if(threadIdx.x % 8 ==0 ) + // printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, i.value, src_coords_[i].GetOffset(), tmp); static_for<0, SrcScalarPerVector, 1>{}( [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); } else { + // if(threadIdx.x % 8 ==0 ) + // printf("bid %d tid %d srcid %d vn\n", blockIdx.y, threadIdx.x, i.value); src_vectors(i).template AsType()(I0) = src_bufs[i].template Get(src_coords_[i].GetOffset(), true); }