diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 604c918212..158ed3cd64 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -149,19 +149,19 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< 2, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; #else -static constexpr ck::index_t MPerBlock = 128; +static constexpr ck::index_t MPerBlock = 32; using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< Row, Col, DsLayout, ELayout, A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - MPerBlock, 128, 128, + 32, 128, 128, 16, 16, 32, 32, - 2, 2, + 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - 2, 1, S<1, 8, 1, 32>, S<2, 1, 1, 1>, + 1, 1, S<1, 8, 1, 32>, S<2, 1, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; #endif // clang-format on @@ -180,11 +180,11 @@ int main(int argc, char* argv[]) ck::index_t N = 6144; ck::index_t K = 4096; ck::index_t experts = 8; - ck::index_t sorted_tile_num = 19; - ck::index_t valid_tile_num = 16; + ck::index_t valid_tile_num = 2; + ck::index_t sorted_tile_num = valid_tile_num + 3; ck::index_t sorted_size = sorted_tile_num * MPerBlock; ck::index_t valid_size = valid_tile_num * MPerBlock; - ck::index_t tokens = 832; + ck::index_t tokens = 1; ck::index_t topk = 2; if(argc == 1) @@ -232,8 +232,9 @@ int main(int argc, char* argv[]) Tensor max_token_id(HostTensorDescriptor({1})); max_token_id.mData = {valid_size, 0, 1, 2, 3, 4, 5, 6, 7, 8}; + int eids[] = {0, 1, 3, 3, 3}; // int eids[] = {0, 1, 2, 3, 4, 5, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} - int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; + // int eids[] = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 3, 3, 3}; for(int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; @@ -269,7 +270,7 @@ int main(int argc, char* argv[]) Tensor b0_e_n_k(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor b1_e_n_k(HostTensorDescriptor( {experts, (K + Scale_Block_K - 1) / Scale_Block_K, (N + Scale_Block_N - 1) / Scale_Block_N}, - {(Scale_Stride_B * Scale_Stride_BN), Scale_Stride_BN, 1})); + {(Scale_Stride_B * Scale_Stride_BN), 1, Scale_Stride_BN})); Tensor b0_preshuffled(HostTensorDescriptor({experts, K, N}, {N * K, 1, K})); Tensor d2_e_n(HostTensorDescriptor({sorted_size, N}, {1, 0})); @@ -322,12 +323,12 @@ int main(int argc, char* argv[]) DeviceMem b1_device_buf(sizeof(B1DataType) * b1_e_n_k.mDesc.GetElementSpaceSize()); DeviceMem d2_device_buf(sizeof(D2DataType) * d2_e_n.mDesc.GetElementSpaceSize()); DeviceMem e_device_buf(sizeof(EDataType) * e_t_n_device_result.mDesc.GetElementSpaceSize()); - // a0_t_k_k.savetxt("a.txt"); - // expert_ids.savetxt("expert_ids.txt", "int"); - // sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); + a0_t_k_k.savetxt("a.txt"); + expert_ids.savetxt("expert_ids.txt", "int"); + sorted_token_ids.savetxt("sorted_token_ids.txt", "int"); // d0_t_n.savetxt("d0_t_n.txt", "int"); // d1_e_n.savetxt("d1_e_n.txt", "int"); - // d2_e_n.savetxt("d2_e_n.txt", "int"); + d2_e_n.savetxt("d2_e_n.txt", "int"); sorted_token_ids_dev.ToDevice(sorted_token_ids.mData.data()); expert_ids_dev.ToDevice(expert_ids.mData.data()); max_token_id_dev.ToDevice(max_token_id.mData.data()); @@ -381,6 +382,71 @@ int main(int argc, char* argv[]) "wrong! device_gemm with the specified compilation parameters does " "not support this GEMM problem"); } + +#if 1 + // printf the input tensor + // printf a tensor + printf("a0_t_k_k: \n"); + for(int t = 0; t < tokens; ++t) + { + for(int tk = 0; tk < topk; ++tk) + { + printf("topk: %d: ", tk); + for(int k = 0; k < K; ++k) + { + printf("%f ", ck::type_convert(a0_t_k_k(t, tk, k))); + } + printf("\n"); + } + } + + // printf a scale tensor + printf("a1_t_k_k: \n"); + for(int t = 0; t < tokens; ++t) + { + for(int tk = 0; tk < topk; ++tk) + { + printf("topk: %d: ", tk); + for(int k = 0; k < (K + Scale_Block_K - 1) / Scale_Block_K; ++k) + { + printf("%f ", ck::type_convert(a1_t_k_k(t, tk, k))); + } + printf("\n"); + } + } + + // printf b tensor + // printf("b0_e_n_k: \n"); + // for (int e=0; e < experts; ++e) + // { + // for (int k=0; k < K; ++k) + // { + // printf("expert: %d: ", e); + // for (int n=0; n < N; ++n) + // { + // printf("%f ", ck::type_convert(b0_e_n_k(e, k, n))); + // } + // printf("\n"); + // } + // } + + // printf b scale tensor + printf("b1_e_n_k: \n"); + for(int e = 0; e < experts; ++e) + { + for(int k = 0; k < (K + Scale_Block_K - 1) / Scale_Block_K; ++k) + { + printf("expert: %d: ", e); + for(int n = 0; n < (N + Scale_Block_N - 1) / Scale_Block_N; ++n) + { + printf("%f ", ck::type_convert(b1_e_n_k(e, k, n))); + } + printf("\n"); + } + } + +#endif + if(time_kernel) { // not result correct here because output buf not setzero diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index 6632fc09c8..cc2cad3860 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -1107,7 +1107,7 @@ struct GridwiseMoeGemmBlockScale } // check gridwise gemm pipeline -#if 1 +#if 0 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value); if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) @@ -1373,14 +1373,15 @@ struct GridwiseMoeGemmBlockScale // get each thread's offset in the scale tensor // A scale - const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockK + a_thread_offset; + const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM; if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens) return; StaticallyIndexedArray scale_gather_offsets; static_for<0, MXdlPerWave, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[token_scale_pos + m0]; - index_t token_offset = fused_token & 0xffffff; + const index_t fused_token = + p_sorted_token_ids[token_scale_pos + m0 * MPerXdl + a_thread_offset]; + index_t token_offset = fused_token & 0xffffff; if constexpr(!IsInputGemm) { token_offset = token_offset * problem.TopK + (fused_token >> 24); @@ -1389,6 +1390,10 @@ struct GridwiseMoeGemmBlockScale token_offset * math::integer_divide_ceil(problem.K, ScaleBlockK); }); + // printf("blkid: %d, tid:%d, a_thread_offset: %d, scale_gather_offsets: %d\n", block_m_id, + // threadIdx.x, a_thread_offset, + // scale_gather_offsets(Number<0>{})); + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2_gather\n", + // blockIdx.y, + // threadIdx.x, + // dst_buf(Number<0>{})); + // move src coordinate back to slice origin (or not) if constexpr(SrcResetCoordinateAfterRun) {