From 91eb136937db8e7ff31b95e9c604f78218a6dff6 Mon Sep 17 00:00:00 2001 From: "Ding, Yi" Date: Mon, 26 May 2025 10:32:26 +0000 Subject: [PATCH] Fix v1; use M padding --- .../blockwise_gemm_pipeline_xdlops_v1_mx.hpp | 26 +++++++------- .../impl/device_gemm_xdl_cshuffle_v3_mx.hpp | 32 +++++++++++++---- .../grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp | 36 +++++++++++++------ .../device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp | 13 +++---- ...dl_f4_f4_f16_mk_nk_mn_default_instance.cpp | 2 +- 5 files changed, 72 insertions(+), 37 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp index 761f676a91..308a73a178 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp @@ -253,8 +253,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx 15 32 --> 47 | 64 --> 79 96 --> 111 | etc. @@ -387,6 +382,13 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto m0) { static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { @@ -507,9 +509,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_mx{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index d48e4dca0f..d088d3775d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -327,13 +327,31 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX 1) && tail_num_choice.value == tail_num) { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< // - GridwiseGemm, - mainloop_choice.value, - CGlobalMemoryDataOperation, - minimum_occupancy, - tail_num_choice.value>; - Run(kernel); + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + const auto kernel = kernel_gemm_xdl_cshuffle_v3< // + GridwiseGemm, + mainloop_choice.value, + CGlobalMemoryDataOperation, + minimum_occupancy, + tail_num_choice.value>; + Run(kernel); + } + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< // + GridwiseGemm, + mainloop_choice.value, + CGlobalMemoryDataOperation, + minimum_occupancy, + tail_num_choice.value>; + Run(kernel); + } + else + { + static_assert(false, "Unexpected BlkGemmPipelineVer!"); + } } }); return ave_time; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index 89cafe2fca..f124152656 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -331,12 +331,28 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // pad M, but not K const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), make_right_pad_transform(M, MPad - M)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return a_grid_desc_ak0_m_ak1; + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(MPad, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(MPad), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + return a_grid_desc; } else if constexpr(GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::NKPadding) @@ -408,8 +424,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 GemmSpec != GemmSpecialization::Default), "pk_i4_t does not support padding"); static_assert(!(is_same_v, f4x2_pk_t> && - GemmSpec != GemmSpecialization::Default), - "f4x2_pk_t does not support padding"); + (GemmSpec != GemmSpecialization::Default && + GemmSpec != GemmSpecialization::MPadding)), + "f4x2_pk_t does not support K padding"); if constexpr(GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding) @@ -1357,6 +1374,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0, "B scale pack data type too large!"); + static_assert(is_same_v && + is_same_v, + "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!"); + template ( p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1895,9 +1914,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto b_scale_grid_buf = make_dynamic_buffer( p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); - static_assert( - is_same_v && - is_same_v); const CElementwiseOperation c_element_op{}; // divide block work by [M, N] diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp index 48b38524ea..2b4c18787a 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp @@ -29,6 +29,7 @@ using PassThrough = element_wise::PassThrough; static constexpr auto GemmDefault = GemmSpecialization::Default; static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMPadding = GemmSpecialization::MPadding; static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; @@ -44,12 +45,12 @@ using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple< //#############################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#############################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp index 597879c414..8045413c99 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instance.cpp @@ -23,7 +23,7 @@ void add_device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_default_instances( PassThrough>>>& instances) { add_device_operation_instances( - instances, device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances{}); + instances, device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances{}); } } // namespace instance