diff --git a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp index 158ed3cd64..5f5eeb3d92 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm2_xdl_fp8_blockscale.cpp @@ -155,7 +155,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemmBlockScale< A0DataType, A1DataType, B0DataType, B1DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, Scale_Block_M, Scale_Block_N, Scale_Block_K, - 32, 128, 128, + MPerBlock, 128, 128, 16, 16, 32, 32, 1, 1, @@ -306,6 +306,34 @@ int main(int argc, char* argv[]) b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); d2_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; + case 4: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); // 1 + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 5: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); // 1 + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 6: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); // 1 + b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); // 1 + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; + case 7: + a0_t_k_k.GenerateTensorValue(GeneratorTensor_1{}); + a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_e_n_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + d2_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + break; default: a0_t_k_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); a1_t_k_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); @@ -383,7 +411,7 @@ int main(int argc, char* argv[]) "not support this GEMM problem"); } -#if 1 +#if 0 // printf the input tensor // printf a tensor printf("a0_t_k_k: \n"); @@ -394,7 +422,7 @@ int main(int argc, char* argv[]) printf("topk: %d: ", tk); for(int k = 0; k < K; ++k) { - printf("%f ", ck::type_convert(a0_t_k_k(t, tk, k))); + printf("%.1f ", ck::type_convert(a0_t_k_k(t, tk, k))); } printf("\n"); } @@ -409,26 +437,26 @@ int main(int argc, char* argv[]) printf("topk: %d: ", tk); for(int k = 0; k < (K + Scale_Block_K - 1) / Scale_Block_K; ++k) { - printf("%f ", ck::type_convert(a1_t_k_k(t, tk, k))); + printf("%.1f ", ck::type_convert(a1_t_k_k(t, tk, k))); } printf("\n"); } } // printf b tensor - // printf("b0_e_n_k: \n"); - // for (int e=0; e < experts; ++e) - // { - // for (int k=0; k < K; ++k) - // { - // printf("expert: %d: ", e); - // for (int n=0; n < N; ++n) - // { - // printf("%f ", ck::type_convert(b0_e_n_k(e, k, n))); - // } - // printf("\n"); - // } - // } + printf("b0_e_n_k: \n"); + for (int e=0; e < experts; ++e) + { + for (int k=0; k < K; ++k) + { + printf("expert: %d: ", e); + for (int n=0; n < N; ++n) + { + printf("%.1f ", ck::type_convert(b0_e_n_k(e, k, n))); + } + printf("\n"); + } + } // printf b scale tensor printf("b1_e_n_k: \n"); @@ -439,7 +467,7 @@ int main(int argc, char* argv[]) printf("expert: %d: ", e); for(int n = 0; n < (N + Scale_Block_N - 1) / Scale_Block_N; ++n) { - printf("%f ", ck::type_convert(b1_e_n_k(e, k, n))); + printf("%.1f ", ck::type_convert(b1_e_n_k(e, k, n))); } printf("\n"); } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp index 25d6491e85..031f9b6683 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_v1.hpp @@ -426,6 +426,12 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< }); }); + // printf("blockIdx.y = %d, blockIdx.x = %d, threadIdx.x = %d, c_scale_thread_buf = <%f>\n", + // blockIdx.y, + // blockIdx.x, + // threadIdx.x, + // c_scale_thread_buf[Number<0>{}]); + // Local prefill A1 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0); @@ -552,6 +558,7 @@ struct BlockwiseGemmXdlops_pipeline_moe_blockscale_bpreshuffle_v1< b_thread_vec.template AsType(), c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})); }); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp index cc2cad3860..5291e89cff 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm_blockscale.hpp @@ -568,6 +568,8 @@ struct GridwiseMoeGemmBlockScale Number{}); } + using DsGridDesc_M_N = remove_cvref_t; + struct Problem { __host__ __device__ Problem(index_t NumTokens_, @@ -1360,7 +1362,7 @@ struct GridwiseMoeGemmBlockScale constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( make_tuple(Number{}, Number{})); - // constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); + constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); auto a_thread_offset = get_thread_local_1d_id() % MPerXdl + (get_thread_local_1d_id() / 64) / NWaves * MPerXdl; @@ -1380,7 +1382,7 @@ struct GridwiseMoeGemmBlockScale StaticallyIndexedArray scale_gather_offsets; static_for<0, MXdlPerWave, 1>{}([&](auto m0) { const index_t fused_token = - p_sorted_token_ids[token_scale_pos + m0 * MPerXdl + a_thread_offset]; + p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset]; index_t token_offset = fused_token & 0xffffff; if constexpr(!IsInputGemm) { @@ -1462,29 +1464,67 @@ struct GridwiseMoeGemmBlockScale // shuffle C and write out { + // // print C + // printf("tid: %d, blkid: %d, " + // "c_thread_buf = <%1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, %1.f, + // %1.f, %1.f, %1.f, %1.f, %1.f, %1.f\n", get_thread_local_1d_id(), block_m_id, + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<0>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<1>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<2>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<3>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<4>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<5>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<6>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<7>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<8>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<9>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<10>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<11>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<12>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<13>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<14>{}], + // c_thread_buf.GetVectorTypeReference(Number<0>{}) .template + // AsType()[Number<15>{}]); + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + // transposed XDL // TODO: hacky, fix it! - constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = - blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(); - constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); - constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); - constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); - constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); - constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6); + constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7); constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -1493,24 +1533,24 @@ struct GridwiseMoeGemmBlockScale static_cast(p_shared), c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); - constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor( c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( Number{}, // M0 (MXdlPerWave) per shuffle M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), + M2)), // M2 = MPerXdl make_freeze_transform(I0), make_unmerge_transform(make_tuple( Number{}, // N0 (NXdlPerWave) per shuffle N1, // N1 = NWave - N2))), // N2 = NPerXdl + N2, // N2 * N3 * N4 = NPerXdl + N3, + N4))), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + Sequence<>{}, Sequence<0, 2, 4>{}, Sequence<>{}, Sequence<1, 3, 5, 6, 7>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -1520,56 +1560,56 @@ struct GridwiseMoeGemmBlockScale const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; - const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + const auto m_thread_data_on_block_to_m0_m1_m2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), - make_tuple(Sequence<0>{})); - - const auto m_thread_data_on_block_idx = - m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( - make_multi_index(m_thread_data_on_block)); - - const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = - make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2))), make_tuple(Sequence<0, 1, 2>{}), make_tuple(Sequence<0>{})); + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + const auto n_thread_data_on_block_idx = - n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex( make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3, + N2, + I1, + N4>, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, 7, 1, InMemoryDataOperationEnum::Set, 1, true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, make_multi_index(0, 0, m_thread_data_on_block_idx[I1], n_thread_data_on_block_idx[I1], m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), + n_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I3], + n_thread_data_on_block_idx[I4]), ck::tensor_operation::element_wise::PassThrough{}}; using EDataType = CDataType; @@ -1665,16 +1705,16 @@ struct GridwiseMoeGemmBlockScale p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = - SpaceFillingCurve, + SpaceFillingCurve, Sequence<0, 1, 2, 3, 4, 5, 6, 7>, Sequence>{}; + N2, + 1, + N4>>{}; constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); @@ -1721,10 +1761,10 @@ struct GridwiseMoeGemmBlockScale block_sync_lds(); // each thread write its data from VGPR to LDS - c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4, sfc_c_vgpr.GetIndexTupleOfNumber(access_id), c_thread_buf, - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4, c_shuffle_block_buf); // make sure it's safe to read from LDS