diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index 61954720b6..1108d8dd65 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -1,9 +1,19 @@ +list(APPEND TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS -mllvm -greedy-reverse-local-assignment=1) +list(APPEND TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + +add_example_executable(example_moe_gemm1 moe_gemm1.cpp) +add_example_executable(example_moe_gemm2 moe_gemm2.cpp) +target_compile_options(example_moe_gemm1 PRIVATE ${TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS}) +target_compile_options(example_moe_gemm2 PRIVATE ${TILE_EXAPMLE_BLOCKSCALE_COMPILE_OPTIONS}) + + add_example_executable(example_gemm_multiply_multiply_xdl_fp8 gemm_multiply_multiply_xdl_fp8.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_ab_scale gemm_multiply_multiply_xdl_fp8_ab_scale.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_fp8_bpreshuffle gemm_multiply_multiply_xdl_fp8_bpreshuffle.cpp) add_example_executable(example_gemm_add_add_xdl_fp16 gemm_add_add_xdl_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_xdl_int8 gemm_multiply_multiply_xdl_int8.cpp) -add_example_executable(example_moe_gemm1 moe_gemm1.cpp) -add_example_executable(example_moe_gemm2 moe_gemm2.cpp) add_example_executable(example_moe_pk_i4_gemm1 moe_pk_i4_gemm1.cpp) +set(EXAMPLE_COMPILE_OPTIONS) +list(APPEND EXAMPLE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker -g -fverbose-asm) +target_compile_options(example_moe_pk_i4_gemm1 PRIVATE ${EXAMPLE_COMPILE_OPTIONS}) add_example_executable(example_moe_pk_i4_gemm2 moe_pk_i4_gemm2.cpp) diff --git a/example/65_gemm_multiply_multiply/moe_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_gemm1.cpp index f61ddb2fe3..c8897e1a75 100644 --- a/example/65_gemm_multiply_multiply/moe_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_gemm1.cpp @@ -133,8 +133,8 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 2; -static constexpr ck::index_t NXDLPerWave = 2; +static constexpr ck::index_t MXDLPerWave = 4; +static constexpr ck::index_t NXDLPerWave = 1; static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; @@ -174,10 +174,11 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm // CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| // MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| // PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - MXDLPerWave, 1, S<1, 32, 1, 8>, S, + 4, 1, S<1, 32, 1, 8>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; // kernel 2: 128->32x128x128 // < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, EDataType>; +// DeviceGemmMultiD_Xdl_CShuffle_V3_BPreshuffle< Row, Col, Tuple, Row, F8, F8, Tuple, F16, F32, F32, PassThrough, PassThrough, MultiplyMultiply, GemmSpec, 256, 128, 128, 128, 16, 16, 32, 32, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, // clang-format on @@ -197,8 +198,6 @@ int main(int argc, char* argv[]) ck::index_t experts = 8; ck::index_t sorted_tile_num = 16; ck::index_t valid_tile_num = 13; - ck::index_t sorted_size = sorted_tile_num * MPerBlock; - ck::index_t valid_size = valid_tile_num * MPerBlock; ck::index_t tokens = 544; ck::index_t topk = 2; @@ -217,6 +216,17 @@ int main(int argc, char* argv[]) K = std::stoi(argv[5]); tokens = std::stoi(argv[6]); } + else if(argc == 9) { + + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + N = std::stoi(argv[4]); + K = std::stoi(argv[5]); + tokens = std::stoi(argv[6]); + sorted_tile_num = std::stoi(argv[7]); + valid_tile_num = std::stoi(argv[8]); + } else { printf("arg1: verification (0=no, 1=yes)\n"); @@ -227,6 +237,8 @@ int main(int argc, char* argv[]) exit(0); } + ck::index_t sorted_size = sorted_tile_num * MPerBlock; + ck::index_t valid_size = valid_tile_num * MPerBlock; if (tokens * topk > valid_size) { printf("err config, tokens * topk > valid_size\n"); @@ -246,8 +258,10 @@ int main(int argc, char* argv[]) Tensor sorted_token_ids(HostTensorDescriptor({sorted_size}, {1})); Tensor max_token_id(HostTensorDescriptor({1 + sorted_tile_num})); // max_token_id.mData = {valid_size, 2, 2, 1, 1, 2, 2, 2,2, 2, 2, 2, 2,1,0,0,0}; + // max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; + // int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} max_token_id.mData = {valid_size, 0, 2, 3, 4, 6, 8, 10, 12, 13}; - int eids[] = {0, 0,1, 2,3, 3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} + int eids[] = {0, 0,1,2, 3,3, 4,4, 5, 5, 6, 6, 7, 3, 3, 3}; // {2, 1, 1, 2, 2, 2, 1, 2} for (int i = 0; i < sorted_tile_num; i++) { expert_ids.mData[i] = eids[i]; } @@ -287,8 +301,8 @@ int main(int argc, char* argv[]) case 1: a0_t_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - d0_t_n.GenerateTensorValue(GeneratorTensor_2{1, 2}); - d1_e_n.GenerateTensorValue(GeneratorTensor_2{1, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d1_e_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); break; case 2: a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); @@ -296,12 +310,21 @@ int main(int argc, char* argv[]) d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); break; + case 3: + a0_t_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_e_n_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + break; default: a0_t_k.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); b0_e_n_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); d0_t_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); d1_e_n.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); } + // d0_t_n.GenerateTensorValue(GeneratorTensor_1{}); + // d1_e_n.GenerateTensorValue(GeneratorTensor_1{}); + // b0_e_n_k.GenerateTensorValue(GeneratorTensor_1{}); DeviceMem sorted_token_ids_dev(sizeof(ck::index_t) * sorted_token_ids.mDesc.GetElementSpaceSize()); DeviceMem expert_ids_dev(sizeof(ck::index_t) * expert_ids.mDesc.GetElementSpaceSize()); DeviceMem max_token_id_dev(sizeof(ck::index_t) * max_token_id.mDesc.GetElementSpaceSize()); diff --git a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp index 2522edff56..7dc661bc3d 100644 --- a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp +++ b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm1.cpp @@ -64,7 +64,11 @@ struct MulABScale const float& d0, const float& d1) const { +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * d1 * d0 * 16); +#else e = ck::type_convert(c * d1 * d0); +#endif } }; @@ -84,7 +88,11 @@ struct MulABScaleSilu { // act float x0 = 0; +#if CK_USE_PK4_LAYOUT_SHUFFLE + ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0 * 16); +#else ck::tensor_operation::element_wise::Silu{}(x0, c * d1 * d0); +#endif e = ck::type_convert(x0); } }; @@ -131,13 +139,13 @@ using BElementOp = PassThrough; static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; #if 0 -static constexpr ck::index_t MPerBlock = 128; -static constexpr ck::index_t MXDLPerWave = 2; +static constexpr ck::index_t MPerBlock = 64; +static constexpr ck::index_t MXDLPerWave = 1; static constexpr ck::index_t NXDLPerWave = 2; static constexpr ck::index_t BLOCKSIZE = 256; static constexpr ck::index_t NPerBlock = 128; static constexpr ck::index_t MNPerXDL = 32; -static constexpr ck::index_t KPerBlock = 256 / sizeof(A0DataType); +static constexpr ck::index_t KPerBlock = 64 / sizeof(A0DataType); static constexpr ck::index_t Nswizzle = false; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); @@ -154,8 +162,8 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< AK1, BK1, MNPerXDL, MNPerXDL, MXDLPerWave, NXDLPerWave, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, MXDLPerWave, 1, S<1, 32, 1, 8>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; // clang-format on @@ -167,13 +175,13 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm< Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, - 256, 128, 128, 64, + 256, MPerBlock, 128, 128, 16, 32, 32, 32, 4, 1, - S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, - S<2, 128, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, - 4, 1, S<1, 32, 1, 8>, S<4, 1, 1>, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 32, 32, 0, + 2, 1, S<1, 32, 1, 8>, S<4, 1, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, Nswizzle, true, A0DataType>; // clang-format on #endif @@ -359,6 +367,7 @@ int main(int argc, char* argv[]) } #endif +#if CK_USE_PK4_LAYOUT_SHUFFLE // vector pk_i4x4 permute for(int e = 0; e < experts; e++) { @@ -410,6 +419,7 @@ int main(int argc, char* argv[]) } } } +#endif b0_device_buf.ToDevice(b0_preshuffled.mData.data()); diff --git a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp index df2fcab6bb..3102cdf51b 100644 --- a/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp +++ b/example/65_gemm_multiply_multiply/moe_pk_i4_gemm2.cpp @@ -69,7 +69,12 @@ struct MulABScaleExpertWeight //for real kernel use //warning: hack hack hack here!!!! ignore d0 right now as kernel mul d0 * d2 outside. tofix:felix (void) d0; + +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * d1 * d2 * 16); +#else e = ck::type_convert(c * d1 * d2); +#endif } // for reference cpu template <> @@ -81,7 +86,11 @@ struct MulABScaleExpertWeight const float& d2) const { // for reference cpu +#if CK_USE_PK4_LAYOUT_SHUFFLE + e = ck::type_convert(c * d0 * d1 * d2 * 16); +#else e = ck::type_convert(c * d0 * d1 * d2); +#endif } }; @@ -137,7 +146,7 @@ static constexpr ck::index_t KPerBlock = 128 / sizeof(A0DataType); static constexpr ck::index_t CShuffleNLane = 32; static constexpr ck::index_t CShuffleMLane = BLOCKSIZE / CShuffleNLane; static constexpr ck::index_t AK1 = 16 / sizeof(A0DataType); -static constexpr ck::index_t BK1 = 16 / sizeof(B0DataType); +static constexpr ck::index_t BK1 = 32 / sizeof(B0DataType); static constexpr ck::index_t EVec = 2; static constexpr ck::index_t D0Vec = 1; static constexpr ck::index_t D1Vec = 1; @@ -151,7 +160,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceMoeGemm MNPerXDL, MNPerXDL, MXDLPerWave, NXDLPerWave, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, - S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, AK1, AK1, 0, + S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, BK1, BK1, 0, MXDLPerWave, 1, S<1, CShuffleMLane, 1, CShuffleNLane>, S, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, false, false, A0DataType>; // clang-format on @@ -320,6 +329,7 @@ int main(int argc, char* argv[]) preShuffleBuffer(b0_e_n_k.mData.data(), b0_preshuffled.mData.data(), N * experts, K, device_op.GetPreShuffleParameters()); +#if CK_USE_PK4_LAYOUT_SHUFFLE // vector pk_i4x4 permute for(int e = 0; e < experts; e++) { @@ -371,6 +381,7 @@ int main(int argc, char* argv[]) } } } +#endif b0_device_buf.ToDevice(b0_preshuffled.mData.data()); @@ -443,8 +454,19 @@ int main(int argc, char* argv[]) auto ref_moe_gemm = ReferenceGemmInstance{}; auto ref_invoker = ref_moe_gemm.MakeInvoker(); - auto ref_argument = ref_moe_gemm.MakeArgument( - sorted_token_ids, expert_ids, max_token_id, MPerBlock, a0_t_k_k, b0_e_n_k, d0_t_n, d1_e_n, d2_e_n, c_t_n, PassThrough{}, PassThrough{}, cde_element_op); + auto ref_argument = ref_moe_gemm.MakeArgument(sorted_token_ids, + expert_ids, + max_token_id, + MPerBlock, + a0_t_k_k, + b0_e_n_k, + d0_t_n, + d1_e_n, + d2_e_n, + c_t_n, + PassThrough{}, + PassThrough{}, + cde_element_op); ref_invoker.Run(ref_argument); for(int t = 0; t < tokens; ++t) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp index 6e1f826dae..83534ec5af 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_dequant_v3.hpp @@ -194,17 +194,17 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_bdequant_v3{}([&](auto i_inst) { ignore = i_inst; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp index 86f03f038d..9781caf28c 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_v1.hpp @@ -190,7 +190,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_v1{}([&](auto i) { ignore = i; - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + __builtin_amdgcn_sched_group_barrier(0x008, 2, 0); // MFMA __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read }); diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp index 4bda56e48b..ac4c6f678e 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r3_scatter.hpp @@ -65,16 +65,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter const StaticallyIndexedArray& src_block_slice_origins, const DstDescs& dst_descs, const StaticallyIndexedArray& dst_block_slice_origins, - const ElementwiseOperation& element_op, - const StaticallyIndexedArray &scatter_offsets, - const StaticallyIndexedArray &scatter_weights) + const ElementwiseOperation& element_op) : threadwise_transfer_(src_descs, StaticallyIndexedArray{}, dst_descs, StaticallyIndexedArray{}, - element_op, - scatter_offsets, - scatter_weights) + element_op) { static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() && nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() && @@ -129,12 +125,13 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, + StaticallyIndexedArray &scatter_weights, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { - threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id); + threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id); } } @@ -144,15 +141,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter template __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, + StaticallyIndexedArray &scatter_offsets, Number thread_scratch_id = Number{}) { if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize()) { if constexpr(is_detected::value) - threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id); + threadwise_transfer_.RunWrite(dst_descs, dst_bufs, scatter_offsets, thread_scratch_id); else - threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id); + threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id); } } @@ -160,10 +158,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter __device__ void Run(const SrcDescs& src_descs, const SrcBuffers& src_bufs, const DstDescs& dst_descs, - DstBuffers dst_bufs) + DstBuffers dst_bufs, + StaticallyIndexedArray &scatter_offsets, + StaticallyIndexedArray &scatter_weights) { - RunRead(src_descs, src_bufs); - RunWrite(dst_descs, dst_bufs); + RunRead(src_descs, src_bufs, scatter_weights); + RunWrite(dst_descs, dst_bufs, scatter_offsets); } template diff --git a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp index 76deacb5ff..f6db7f5b6e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp @@ -247,6 +247,7 @@ struct DeviceMoeGemm // static_assert(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && // has_main_k_block_loop, "only impl BlockGemmPipelineVersion::v3 and has mainloop right // now"); + constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; if(has_main_k_block_loop) { // Tail number always full @@ -279,7 +280,6 @@ struct DeviceMoeGemm // } // else { - constexpr auto MemoryDataOp = IsInputGemm ? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd; // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) // { // const auto kernel = kernel_moe_gemm< @@ -304,8 +304,9 @@ struct DeviceMoeGemm } } } - // else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - // { + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { // if(arg.KBatch > 1) // { // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) @@ -332,31 +333,29 @@ struct DeviceMoeGemm // } // } // else - // { - // if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - // { - // const auto kernel = - // kernel_moe_gemm_gather_2lds< - // GridwiseGemm, - // true, - // IsInputGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd, - // minimum_occupancy, - // TailNumber::Odd>; - // RunKernel(kernel); - // } - // else - // { - // const auto kernel = - // kernel_moe_gemm_gather_2lds< - // GridwiseGemm, - // true, - // IsInputGemm? InMemoryDataOperationEnum::Set : InMemoryDataOperationEnum::AtomicAdd, - // minimum_occupancy, - // TailNumber::Even>; - // RunKernel(kernel); - // } - // } - // } + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + else + { + const auto kernel = kernel_moe_gemm_2lds; + RunKernel(kernel); + } + } + } else { throw std::runtime_error("todo: only v1 & v2 support now"); diff --git a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp index 17ea38aa0d..037a026eca 100644 --- a/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp +++ b/include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp @@ -79,13 +79,30 @@ __device__ inline half4_t i4_to_half4_scale(int q, const ck::half2_t& scale) return res.template AsType()[Number<0>{}]; } +__device__ inline f8x4_t i4_to_f8x4(int q) +{ + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + + int lo = amd_assembly_and_b32(q, LO); + int hi = amd_assembly_and_b32(q, HI); + + float f32_0 = amd_assemble_cvt_f32_i4(lo); + float f32_1 = amd_assemble_cvt_f32_i4(lo >> 16); + float f32_2 = amd_assemble_cvt_f32_i4(hi); + float f32_3 = amd_assemble_cvt_f32_i4(hi >> 16); + + // vector_type res; + // res.template AsType()(Number<0>{}) = amd_assemble_cvt_f8_f32(f32_1st, f32_2nd, f32_3rd, f32_4th); + return amd_assembly_cvt_f8_to_f32(f32_0, f32_1, f32_2, f32_3); +} + __device__ inline f8x8_t i4_to_fp8x8(int q) { - vector_type res; - - res.template AsType()(Number<0>{}) = amd_assembly_i4_to_fp8x2(q); - - return res.template AsType()[Number<0>{}]; + // f8x8_t res; + // amd_assembly_i4_to_fp8x8(res, q); + // return res; + return amd_assembly_i4_to_fp8x8(q); } __device__ inline bhalf4_t i4_to_bhalf4(int q) @@ -154,13 +171,55 @@ struct PassThroughPack8 __host__ __device__ constexpr void operator()(ck::f8x8_t& y, const ck::pk_i4x4_t& x) const { #if CK_USE_PK4_LAYOUT_SHUFFLE - vector_type result; + y = i4_to_fp8x8(bit_cast(x)); - result.template AsType()(Number<0>{}) = i4_to_fp8x8(bit_cast(x)); + // vector_type result; - y = result.template AsType()[Number<0>{}]; + // result.template AsType()(Number<0>{}) = i4_to_f8x4(bit_cast(x)); + // result.template AsType()(Number<1>{}) = i4_to_f8x4(bit_cast(x) >> 8); + + // y = result.template AsType()[Number<0>{}]; #else // Added pk_i4_t to f8x2_fnuz_t conversion + vector_type dst; + vector_type dst_tmp; + vector_type src{x}; + + // pk_i4_t to float2_t conversion + dst_tmp.template AsType()(Number<0>{}) = + type_convert(src.template AsType()[Number<0>{}]); + + dst_tmp.template AsType()(Number<1>{}) = + type_convert(src.template AsType()[Number<1>{}]); + + dst_tmp.template AsType()(Number<2>{}) = + type_convert(src.template AsType()[Number<2>{}]); + + dst_tmp.template AsType()(Number<3>{}) = + type_convert(src.template AsType()[Number<3>{}]); + + // float to f8_t conversion + dst.template AsType()(Number<0>{}) = + type_convert(dst_tmp.template AsType()[Number<0>{}]); + dst.template AsType()(Number<1>{}) = + type_convert(dst_tmp.template AsType()[Number<1>{}]); + + dst.template AsType()(Number<2>{}) = + type_convert(dst_tmp.template AsType()[Number<2>{}]); + dst.template AsType()(Number<3>{}) = + type_convert(dst_tmp.template AsType()[Number<3>{}]); + + dst.template AsType()(Number<4>{}) = + type_convert(dst_tmp.template AsType()[Number<4>{}]); + dst.template AsType()(Number<5>{}) = + type_convert(dst_tmp.template AsType()[Number<5>{}]); + + dst.template AsType()(Number<6>{}) = + type_convert(dst_tmp.template AsType()[Number<6>{}]); + dst.template AsType()(Number<7>{}) = + type_convert(dst_tmp.template AsType()[Number<7>{}]); + + y = dst.template AsType()[Number<0>{}]; #endif } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp index 86b670997d..854741e039 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_moe_gemm.hpp @@ -62,39 +62,43 @@ __global__ void #endif // end of if (defined(__gfx9__)) } -// template -// __global__ void -// #if CK_USE_LAUNCH_BOUNDS -// __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -// #endif -// // __attribute__((amdgpu_waves_per_eu(1, 1))) -// kernel_moe_gemm_gather_2lds(typename GridwiseGemm::Argument karg) -// { -// #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) -// __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; -// __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + // __attribute__((amdgpu_waves_per_eu(1, 1))) + kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg) +{ +#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; -// auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); -// GridwiseGemm::template Run_2Lds( -// karg.p_a_grid + splitk_batch_offset.a_k_split_offset, -// karg.p_b_grid + splitk_batch_offset.b_k_split_offset, -// karg.p_ds_grid, -// karg.p_c_grid, -// p_shared, -// p_shared1, -// karg, -// karg.a_element_op, -// karg.b_element_op, -// karg.c_element_op); -// #else -// ignore = karg; -// #endif // end of if (defined(__gfx9__)) -// } + GridwiseGemm::template Run_2Lds( + karg.p_sorted_token_ids, + karg.p_sorted_expert_ids, + karg.p_max_token_id, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_ds_grid, + karg.p_c_grid, + p_shared, + p_shared1, + karg, + karg.a_element_op, + karg.b_element_op, + karg.c_element_op); +#else + ignore = karg; +#endif // end of if (defined(__gfx9__)) +} template ::value) { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(LDSTypeA) /APackedSize < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(LDSTypeA); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple( - AK0Number * Number{}, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + constexpr auto a_lds_block_desc = + make_naive_tensor_descriptor(make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( a_lds_block_desc, - make_tuple(make_xor_with_modulo_transform(make_tuple( - Number{}, Number{})), + make_tuple(make_xor_with_modulo_transform( + make_tuple(Number{}, Number{})), make_pass_through_transform(AK1Number)), make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{})); - constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number{})), - make_pass_through_transform(Number{}), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{})); - - constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_ak0_mldslayer_m_ak1, - make_tuple(make_pass_through_transform(AK0Number), - make_merge_transform_v3_division_mod( - make_tuple(Number{}, Number{})), - make_pass_through_transform(AK1Number)), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); - - return a_lds_block_desc_ak0_m_ak1; + return a_lds_block_desc_permuted; } else // ColumnMajor A { @@ -1508,32 +1491,6 @@ struct GridwiseMoeGemm using CDEBlockTransferCluster = CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; - constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); - constexpr auto EMRepeats = MPerBlock / EMThreads; - constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); - const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats; - StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; - StaticallyIndexedArray scatter_weights; //= for topk - // too hack here, 2 specific for topk weights, fixme - const float *p_sorted_weights_0 = p_ds_grid[I0]; - // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; - - static_for<0, EMRepeats, 1>{}([&](auto m0) { - const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; - index_t token_offset = fused_token & 0xffffff; - float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; - if constexpr (IsInputGemm) - { - token_offset = token_offset * problem.TopK + (fused_token >> 24); - } else { - const float *p_sorted_weights_2 = p_ds_grid[I2]; - weight = weight * p_sorted_weights_2[c_token_pos + m0]; - } - scatter_offsets(m0) = token_offset * problem.N; - scatter_weights(m0) = weight; - // if(threadIdx.x % 16 == 0) - // printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0)); - }); constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< ThisThreadBlock, @@ -1569,9 +1526,7 @@ struct GridwiseMoeGemm idx_c_ds_block_begin, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), make_tuple(make_multi_index(0, 0, block_n_id, 0)), - c_element_op, - scatter_offsets, - scatter_weights}; + c_element_op}; auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); @@ -1600,8 +1555,37 @@ struct GridwiseMoeGemm CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + const float *p_sorted_weights_0 = p_ds_grid[I0]; static_for<0, num_access, 1>{}([&](auto access_id) { // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray scatter_weights; //= for topk + // too hack here, 2 specific for topk weights, fixme + // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; + if constexpr (IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } else { + const float *p_sorted_weights_2 = p_ds_grid[I2]; + weight = weight * p_sorted_weights_2[c_token_pos + m0]; + } + + // if(threadIdx.x % 8 == 0 && blockIdx.x == 0) + // printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight); + scatter_offsets(m0) = token_offset * problem.N; + scatter_weights(m0) = weight; + }); + block_sync_lds(); // each thread write its data from VGPR to LDS @@ -1619,7 +1603,522 @@ struct GridwiseMoeGemm c_ds_desc_refs, c_ds_buf_refs, tie(e_grid_desc_mblock_mperblock_nblock_nperblock), - tie(c_grid_buf)); + tie(c_grid_buf), + scatter_offsets, + scatter_weights + ); + + if constexpr(access_id < num_access - 1) + { + constexpr auto cde_lds_and_global_step = + sfc_cde_block.GetForwardStep(access_id); + + // move on Ds + static_for<0, NumDTensor, 1>{}([&](auto i) { + cde_block_copy_lds_and_global.MoveSrcSliceWindow( + c_ds_desc_refs, i + I1, cde_lds_and_global_step); + }); + + // move on E + cde_block_copy_lds_and_global.MoveDstSliceWindow( + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + I0, + cde_lds_and_global_step); + } + }); + } + } + + template + __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + const index_t* p_sorted_expert_ids, + const index_t* p_max_token_id, + const ADataType* p_a_grid, + const BDataType* p_b_grid, + DsGridPointer& p_ds_grid, + CDataType* p_c_grid, + void* p_shared, + void* p_shared1, + const Problem& problem, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op) + { + ignore = b_element_op; + const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1( + IsInputGemm? problem.NumTokens : problem.NumTokens * problem.TopK, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0); + const auto b_grid_desc_bpreshuffled = + MakeBGridDescriptor_Preshuffled(problem.BN0Shuffled, problem.BK0Shuffled); + const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( + IsInputGemm? problem.NumTokens * problem.TopK : problem.NumTokens , problem.MPadded, problem.N, problem.NPadded, problem.StrideC); + // printf("tido %d size %d %d MNBLOCK %d %d %d %d\n", threadIdx.x, problem.StrideC, c_grid_desc_m_n.GetElementSpaceSize(), + // problem.MBlock, problem.NBlock, MPerBlock, NPerBlock); + const auto c_grid_desc_mblock_mperblock_nblock_nperblock = + MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + c_grid_desc_m_n, problem.MBlock, problem.NBlock); + const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]); + // constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2}; + // const index_t b_block_id = blockIdx.x % problem.NBlock; + const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y; + if (expert_block_id * MPerBlock >= max_token_id) + return; + const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]); + const auto block_mn = [&]() -> std::pair { + if constexpr (NSwizzle) + { + // const index_t expert_block_id = blockIdx.x / problem.NBlock; // + // const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]); + // const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1]; + // const index_t expert_block_swizzle = expert_block_id / expert_swizzle; + // const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle); + // const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8); + // const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle); + // if(threadIdx.x==0) + // printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, es, p_sorted_expert_ids[expert_block_id]); + + const index_t ecnt_prefix = p_max_token_id[1+expert_id]; + const index_t prefix_block = ecnt_prefix * problem.NBlock; + const index_t ecnt = p_max_token_id[2+expert_id] - ecnt_prefix; + const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; //p_max_token_id[expert_id + 1]; // 2 + const index_t bid_new = blockIdx.x - prefix_block; + const index_t nid = __builtin_amdgcn_readfirstlane(bid_new % 8 + bid_new / (8 * expert_swizzle) * 8); + const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle); + // if(threadIdx.x==0) + // printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, ecnt, expert_id); + return {nid, mid}; + } else { + return {blockIdx.x, blockIdx.y}; + } + }(); + const index_t block_n_id = block_mn.first; + const index_t block_m_id = block_mn.second; + + // if (threadIdx.x==0) { + // printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id); + // } + const index_t token0 = __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff); + + // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1); + constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0); + constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2); + constexpr auto AKThreads = AK0Threads * AK1Threads; + constexpr auto AMRepeats = MPerBlock / AMThreads; + const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats; + + if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens) + return; + StaticallyIndexedArray gather_offsets; //= p_sorted_token_ids[token_pos]; + static_for<0, AMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + if constexpr (!IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } + gather_offsets(m0) = token_offset * problem.K; + // printf("init off tid %d m %d off %d\n", threadIdx.x, m0(), gather_offsets(m0)); + }); + const index_t expert_stride = __builtin_amdgcn_readfirstlane(problem.N * problem.K); + + // N0, K0, Blocksize*KPack + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave); + + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid + expert_id * expert_stride / BPackedSize, b_grid_desc_bpreshuffled.GetElementSpaceSize()); + // if(threadIdx.x==0) + // printf("tid %d eid %d expert_stride %d bufsize %d\n", + // threadIdx.x, expert_id, expert_stride, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1(); + + // B matrix in LDS memory, dst of blockwise copy + // dummy + constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); + // A matrix blockwise copy + auto a_blockwise_copy = + ThreadGroupTensorSliceTransfer_v4r1_mod8, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + LDSTypeA, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + Sequence<0, 1, 2>, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true, + 1, + 2>( + a_grid_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + a_element_op, + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0), + ck::tensor_operation::element_wise::PassThrough{}, + gather_offsets); + + // Thread-wise copy + // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack + auto b_block_buf_ping = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_buf_pong = make_static_buffer( + b_block_desc_bk0_n_bk1.GetElementSpaceSize()); + auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong); + + auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2< + BDataType, + BDataType, + decltype(b_grid_desc_bpreshuffled), + decltype(b_block_desc_bk0_n_bk1), + Sequence{}, I1, Number{}, Number{}>, + Sequence<1, 2, 0, 3>, + 3, + BBlockTransferSrcScalarPerVector, + BThreadTransferSrcResetCoordinateAfterRun, + true>(b_grid_desc_bpreshuffled, + make_multi_index(n_block_data_idx_on_grid, + get_warp_local_1d_id() % NWave, + 0, + KPack * (get_thread_local_1d_id() % warpSize))); + + // LDS allocation for A and B: be careful of alignment + // Cast after lds + auto a_block_buf_ping = make_dynamic_buffer( + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_buf_pong = make_dynamic_buffer( + static_cast(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong); + + constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, KRepeat, 0); + + // Blockwise GEMM pipeline + static_assert(std::is_default_constructible_v); + auto blockwise_gemm_pipeline = BlockwiseGemmPipe{}; + auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); + + const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( + (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / + KPerBlock); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bpreshuffled, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, + num_k_block_main_loop); + + // shuffle C and write out + { + static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && + NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, + "wrong!"); + + constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); + + // TODO: hacky, fix it! + constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 = + blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + // TODO: hacky, fix it! + // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + + constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); + constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); + constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); + constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); + constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + + constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = + GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); + + auto c_shuffle_block_buf = make_dynamic_buffer( + static_cast(p_shared), + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + + constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor( + c_shuffle_block_desc_mblock_mperblock_nblock_nperblock, + make_tuple( + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // M0 (MXdlPerWave) per shuffle + M1, // M1 = MWave + M2, // M2 * M3 * M4 = MPerXdl + M3, + M4)), + make_freeze_transform(I0), + make_unmerge_transform(make_tuple( + Number{}, // N0 (NXdlPerWave) per shuffle + N1, // N1 = NWave + N2))), // N2 = NPerXdl + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple( + Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0]; + const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1]; + + const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_block_idx = + m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_block)); + + const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_block_idx = + n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_block)); + + // shuffle: threadwise copy C from VGPR to LDS + auto c_thread_copy_vgpr_to_lds = + ThreadwiseTensorSliceTransfer_v1r3, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + 7, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{ + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + n_thread_data_on_block_idx[I2]), + ck::tensor_operation::element_wise::PassThrough{}}; + + using EDataType = CDataType; + + const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N( + problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs); + + const auto ds_grid_desc_mblock_mperblock_nblock_nperblock = + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n, problem.MBlock, problem.NBlock); + + const auto ds_grid_buf = generate_tuple( + [&](auto i) { + using DDataType = remove_cvref_t>; + const DDataType *ptr_ = p_ds_grid[i]; + // hack logic here to support different kind of strides. todo fix it. + // ascale t, 1; bscale E, N, 1, move ptr to E + if (i.value == 1) + { + ptr_ += expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1); + // if ( threadIdx.x % 16 ==0) + // printf("bid %d eid %d b eoff %d %f\n", blockIdx.y, expert_id, expert_id * (problem.StrideDs[1]? problem.StrideDs[1] * problem.N : 1), ptr_[0]); + } + return make_dynamic_buffer( + ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize()); + }, + Number{}); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_desc_refs = concat_tuple_of_reference( + tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; }, + Number{})); + + // tuple of reference to C/Ds tensor descriptors + const auto c_ds_buf_refs = concat_tuple_of_reference( + tie(c_shuffle_block_buf), + generate_tie( + [&](auto i) -> const auto& // return type should be reference + { return ds_grid_buf[i]; }, + Number{})); + + // tuple of starting index of C/Ds blockwise copy + const auto idx_c_ds_block_begin = container_concat( + make_tuple(make_multi_index(0, 0, 0, 0)), + generate_tuple( + [&](auto) { + return make_multi_index(block_m_id, 0, block_n_id, 0); + // return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0); + }, + Number{})); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + c_grid_desc_mblock_mperblock_nblock_nperblock; + + using CDEBlockTransferCluster = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock; + const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation; + constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix + auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter< + ThisThreadBlock, + decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})), + Tuple, + decltype(c_ds_desc_refs), + decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)), + CElementwiseOperation, + Sequence(EGlobalMemoryDataOperation)>, // FIXME: make Sequence + // support arbitray type + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths, + CDEBlockTransferCluster, + Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder, + Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder, + Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder, + 3, // index_t SrcVectorDim, + 3, // index_t DstVectorDim, + CDEShuffleBlockTransferScalarPerVectors, + CShuffleBlockTransferScalarPerVector_NPerBlock, + sequence_merge_t< + Sequence, + uniform_sequence_gen_t>, // ThreadTransferSrcResetCoordinateAfterRunFlags + Sequence, // ThreadTransferDstResetCoordinateAfterRunFlags + 1, //ScatterDim + true, //OutputScatter: false, only use scatter weights + scatter_weight_idx // ScatterWeightIdx: ascale + > + {c_ds_desc_refs, + idx_c_ds_block_begin, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + make_tuple(make_multi_index(0, 0, block_n_id, 0)), + c_element_op}; + + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + // space filling curve for threadwise C in VGPR + constexpr auto sfc_c_vgpr = + SpaceFillingCurve, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence>{}; + + constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess(); + + // space filling curve for shuffled blockwise C/D/E + constexpr auto sfc_cde_block = + SpaceFillingCurve, + Sequence<0, 2, 1, 3>, + Sequence<1, + CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl, + 1, + CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{}; + + static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!"); + constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1); + constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads; + constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3); + const float *p_sorted_weights_0 = p_ds_grid[I0]; + static_for<0, num_access, 1>{}([&](auto access_id) { + // make sure it's safe to write to LDS + StaticallyIndexedArray scatter_offsets; //= p_sorted_token_ids[c_token_pos]; + StaticallyIndexedArray scatter_weights; //= for topk + // too hack here, 2 specific for topk weights, fixme + // const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24; + + auto dstidx = sfc_cde_block.GetIndex(access_id); + const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1); + static_for<0, EMRepeats, 1>{}([&](auto m0) { + const index_t fused_token = p_sorted_token_ids[c_token_pos + m0]; + index_t token_offset = fused_token & 0xffffff; + float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]]; + if constexpr (IsInputGemm) + { + token_offset = token_offset * problem.TopK + (fused_token >> 24); + } else { + const float *p_sorted_weights_2 = p_ds_grid[I2]; + weight = weight * p_sorted_weights_2[c_token_pos + m0]; + } + + // if(threadIdx.x % 8 == 0 && blockIdx.x == 0) + // printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight); + scatter_offsets(m0) = token_offset * problem.N; + scatter_weights(m0) = weight; + }); + + block_sync_lds(); + + // each thread write its data from VGPR to LDS + c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2, + sfc_c_vgpr.GetIndexTupleOfNumber(access_id), + c_thread_buf, + c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + c_shuffle_block_buf); + + // make sure it's safe to read from LDS + block_sync_lds(); + + // each block copy its data from LDS to global + cde_block_copy_lds_and_global.Run( + c_ds_desc_refs, + c_ds_buf_refs, + tie(e_grid_desc_mblock_mperblock_nblock_nperblock), + tie(c_grid_buf), + scatter_offsets, + scatter_weights + ); if constexpr(access_id < num_access - 1) { @@ -1644,8 +2143,12 @@ struct GridwiseMoeGemm // template - // __device__ static void Run_2Lds(const ADataType* p_a_grid, + // __device__ static void Run_2Lds(const index_t* p_sorted_token_ids, + // const index_t* p_sorted_expert_ids, + // const index_t* p_max_token_id, + // const ADataType* p_a_grid, // const BDataType* p_b_grid, // DsGridPointer& p_ds_grid, // CDataType* p_c_grid, @@ -1656,37 +2159,7 @@ struct GridwiseMoeGemm // BElementwiseOperation b_element_op, // CElementwiseOperation c_element_op) // { - // // const auto block_2_ctile_map = Block2CTileMapDefault{problem.M, problem.N, 4}; - // // Run_2Lds( - // // p_a_grid, - // // p_b_grid, - // // p_ds_grid, - // // p_c_grid, - // // p_shared, - // // p_shared1, - // // problem, - // // a_element_op, - // // b_element_op, - // // c_element_op, - // // block_2_ctile_map); - // } - // template - // __device__ static void Run_2Lds(const ADataType* p_a_grid, - // const BDataType* p_b_grid, - // DsGridPointer& p_ds_grid, - // CDataType* p_c_grid, - // void* p_shared, - // void* p_shared1, - // const Problem& problem, - // AElementwiseOperation a_element_op, - // BElementwiseOperation b_element_op, - // CElementwiseOperation c_element_op, - // const Block2CTileMap& block_2_ctile_map) - // { // } }; diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index b62c18538d..fb1ea640ff 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -100,14 +100,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter const StaticallyIndexedArray& src_slice_origins, const DstDescs& dst_descs, const StaticallyIndexedArray& dst_slice_origins, - const ElementwiseOperation& element_op, - const StaticallyIndexedArray &scatter_offsets, - const StaticallyIndexedArray &scatter_weights) + const ElementwiseOperation& element_op) : src_coords_(MakeCoordinates(src_descs, src_slice_origins)), dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)), - element_op_(element_op), - scatter_offsets_(scatter_offsets), - scatter_weights_(scatter_weights) + element_op_(element_op) { static_assert(SliceLengths::At(Number{}) % SrcScalarPerVector == 0, "wrong! cannot evenly divide"); @@ -158,6 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunRead(const SrcDescs& src_descs, const SrcBuffers& src_bufs, + StaticallyIndexedArray &scatter_weights, Number thread_scratch_id = Number{}) { // loop over space-filling curve @@ -181,9 +178,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter static_assert(SrcScalarPerVectors{}[Number{}] == 1, "scatter weight dim, should only one vec"); constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); // if(threadIdx.x % 8 ==0 ) - // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number{})); + // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights(Number{})); static_for<0, SrcScalarPerVector, 1>{}( - [&](auto j) { src_vectors(i).template AsType()(j) = scatter_weights_(Number{}); }); + [&](auto j) { src_vectors(i).template AsType()(j) = scatter_weights(Number{}); }); } else if constexpr(SrcScalarPerVectors{}[i] == 1) { @@ -418,6 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter enable_if_t = false> __device__ void RunWrite(const DstDescs& dst_descs, DstBuffers dst_bufs, + StaticallyIndexedArray &scatter_offsets, Number thread_scratch_id = Number{}) { OOBCheck(thread_scratch_id); @@ -430,13 +428,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter if constexpr (OutputScatter) { constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number{}); - scatter_offset = scatter_offsets_(Number{}); + scatter_offset = scatter_offsets(Number{}); } // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { using dst_vector_t = typename remove_cvref_t::type; auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); - const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();//hack felix, todo use coord + const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], // dst_coords_[i]); @@ -449,11 +447,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter dst_offset, is_dst_valid, dst_vectors[i].template AsType()[I0]); - // if(1) { - // static_for<0, DstScalarPerVector, 1>{}([&](auto idx) { + // if(threadIdx.x%8 ==0 && blockIdx.x==0) { + // static_for<0, 1, 1>{}([&](auto idx) { // using DstData = remove_cvref_t>; // using print_vec_t = typename vector_type::type; - // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid, + // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid, // type_convert(dst_vectors[i].template AsType()[idx])); // }); // } @@ -509,10 +507,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter __device__ void Run(const SrcDescs& src_descs, const SrcBuffers& src_bufs, const DstDescs& dst_descs, - DstBuffers dst_bufs) + DstBuffers dst_bufs, + StaticallyIndexedArray &scatter_offsets, + StaticallyIndexedArray &scatter_weights) { - RunRead(src_descs, src_bufs); - RunWrite(dst_descs, dst_bufs); + RunRead(src_descs, src_bufs, scatter_weights); + RunWrite(dst_descs, dst_bufs, scatter_offsets); } __device__ static constexpr auto GetSrcCoordinateResetStep() @@ -683,8 +683,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter ? dst_slice_origin_step_idx : dst_slice_origin_step_idx + GetDstCoordinateResetStep(); + auto adjusted_step_idx_scatter = [&]() + { + Index step_; + static_for<0, nDim, 1>{}([&](auto i) { + step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number{}]; + }); + + return step_; + } + (); // is it OK to construct a new step every time? - const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx); + const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter); move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step); } @@ -709,8 +719,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter SrcCoords src_coords_; DstCoords dst_coords_; const ElementwiseOperation element_op_; - StaticallyIndexedArray scatter_offsets_; - StaticallyIndexedArray scatter_weights_; }; } // namespace ck diff --git a/include/ck/utility/amd_inline_asm.hpp b/include/ck/utility/amd_inline_asm.hpp index ba27d63dbd..de59f200f0 100644 --- a/include/ck/utility/amd_inline_asm.hpp +++ b/include/ck/utility/amd_inline_asm.hpp @@ -11,6 +11,13 @@ namespace ck { +inline __device__ int amd_assembly_and_b32(int a, int b) +{ + int c; + asm volatile("v_and_b32 %0, %1, %2" : "=v"(c) : "v"(a), "v"(b)); + return c; +} + inline __device__ int amd_assembly_and_or_b32(int a, int b, int d) { int c; @@ -32,7 +39,24 @@ inline __device__ half2_t amd_assembly_pk_add_f16(half2_t a, half2_t b) return c; } -inline __device__ f8x8_t amd_assembly_i4_to_fp8x2(int a) +inline __device__ float amd_assemble_cvt_f32_i4(int b) +{ + float a; + asm volatile("v_cvt_off_f32_i4 %0, %1" : "=v"(a) : "v"(b)); + return a; +} + +inline __device__ f8x4_t amd_assembly_cvt_f8_to_f32(float b0, float b1, float b2, float b3) +{ + f8x4_t a; + asm volatile("v_cvt_pk_fp8_f32 %0, %1, %2\n" + "v_cvt_pk_fp8_f32 %0, %3, %4, op_sel:[0, 0, 1]\n" + : "=v"(a) + : "v"(b0), "v"(b1), "v"(b2), "v"(b3)); + return a; +} + +inline __device__ f8x8_t amd_assembly_i4_to_fp8x8(int a) { uint32_t i4x8 = static_cast(a); uint32_t fp8x4_0; @@ -60,14 +84,7 @@ inline __device__ f8x8_t amd_assembly_i4_to_fp8x2(int a) [v_src] "+v"(i4x8) :); - union - { - uint64_t as_uint64; - f8x8_t as_f8x8; - } convert; - - convert.as_uint64 = (static_cast(fp8x4_1) << 32) | fp8x4_0; - return convert.as_f8x8; + return bit_cast(((static_cast(fp8x4_1) << 32) | fp8x4_0)); } // c0 += inner_product(a, b0) diff --git a/include/ck/utility/data_type.hpp b/include/ck/utility/data_type.hpp index 2e3b09eae9..116804bf1c 100644 --- a/include/ck/utility/data_type.hpp +++ b/include/ck/utility/data_type.hpp @@ -1191,11 +1191,15 @@ struct vector_type()>> StaticallyIndexedArray d8x4_; StaticallyIndexedArray d16x2_; StaticallyIndexedArray d32x1_; - } data_; + } data_ = { .d32_ = {0} }; - __host__ __device__ constexpr vector_type() : data_{type{0}} {} + __attribute__((host)) __attribute__((device)) constexpr vector_type() { } + + __attribute__((host)) __attribute__((device)) constexpr vector_type(type v) { } - __host__ __device__ constexpr vector_type(type v) : data_{v} {} + // __host__ __device__ constexpr vector_type() : data_{type{0}} {} + + // __host__ __device__ constexpr vector_type(type v) : data_{v} {} template __host__ __device__ constexpr const auto& AsType() const diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp index 830662216a..0388798575 100644 --- a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -877,19 +877,11 @@ struct MoeSortingKernel } __syncthreads(); } - if (tid == 0) { - //temp hack ptr for expert tile cnt - p_total_tokens_post_pad[1] = 0; - } for(int i_e = tid; i_e < num_experts; i_e += block_size) { int e_start = smem_cumsum(i_e); int e_end = smem_cumsum(i_e + 1); - //temp hack ptr for expert tile cnt - index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + i_e; - p_sorted_expert_cnts[1] = unit_size_mdiv.div(e_end); - int expert_id = [&]() { if constexpr(Problem::LocalExpertMasking) { @@ -994,11 +986,18 @@ struct MoeSortingKernel __syncthreads(); } + if (tid == 0) { + //temp hack ptr for expert tile cnt + p_total_tokens_post_pad[1] = 0; + } // add the skip number for(int eid = tid; eid < num_experts; eid += block_size) { + //temp hack ptr for expert tile cnt + index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + eid; int e_start = smem_cumsum(eid); int e_end = smem_cumdup(eid + 1); + p_sorted_expert_cnts[1] = unit_size_mdiv.div(e_end); if constexpr(Problem::SkipExpertsWithZeroTokens) { if(e_start == e_end) // skip zero token expert @@ -1682,6 +1681,8 @@ struct MoeSortingMultiPhaseKernel_P2 if(position < kargs.num_experts) { + index_t *p_sorted_expert_cnts = p_total_tokens_post_pad + 1 + position;//temp mock for p_sorted_expert_cnts, fixme:felix + p_sorted_expert_cnts[0] = out_0; p_expert_cumsum[position] = out_0 * kargs.unit_size_mdiv.divisor; } @@ -1710,6 +1711,7 @@ struct MoeSortingMultiPhaseKernel_P2 { auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor; p_total_tokens_post_pad[0] = total_tokens_post_pad; + p_total_tokens_post_pad[kargs.num_experts+1] = prev_cumsum_a; //temp mock for p_sorted_expert_cnts, fixme:felix p_expert_cumsum[kargs.num_experts] = total_tokens_post_pad; } } diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp index 6958d5e796..70c98efec0 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm.hpp @@ -91,7 +91,11 @@ struct ReferenceMoeGemm : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); +#else + v_a = i4 - 8; +#endif } else { @@ -106,7 +110,11 @@ struct ReferenceMoeGemm : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE v_b = i4_to_f32_gfx9(i4); +#else + v_b = i4 - 8; +#endif } else { diff --git a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp index 522e2233b9..3ee6f57ae9 100644 --- a/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp +++ b/library/include/ck/library/reference_tensor_operation/cpu/reference_moe_gemm2.hpp @@ -106,7 +106,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE v_a = i4_to_f32_gfx9(i4); +#else + v_a = i4 - 8; +#endif } else { @@ -120,7 +124,11 @@ struct ReferenceMoeGemm2 : public device::BaseOperator i4 = (i4x2 >> 0) & 0xf; else i4 = (i4x2 >> 4) & 0xf; +#if CK_USE_PK4_LAYOUT_SHUFFLE v_b = i4_to_f32_gfx9(i4); +#else + v_b = i4 - 8; +#endif } else { @@ -198,6 +206,7 @@ struct ReferenceMoeGemm2 : public device::BaseOperator return str.str(); } +#if CK_USE_PK4_LAYOUT_SHUFFLE static float i4_to_f32_gfx9(uint8_t i4) { static std::unordered_map u = {{0b1000, -0.5000f}, @@ -219,6 +228,8 @@ struct ReferenceMoeGemm2 : public device::BaseOperator return u[i4]; } +#endif + }; } // namespace host