diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index ff8762dce2..3e1f4c6268 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -28,10 +28,6 @@ struct GroupedConvolutionForwardInvoker static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s) { - if(s.log_level_ > 0) - { - std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; - } // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index c4bc035a0f..f168d36cac 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -20,11 +20,6 @@ struct GroupedConvolutionForwardInvoker static float grouped_conv_fwd(const ck_tile::GroupedConvFwdHostArgs& args, const ck_tile::stream_config& s) { - if(s.log_level_ > 0) - { - std::cout << "[INVOKER] grouped_conv_fwd called, NDimSpatial=" << NDimSpatial << "\n"; - } - // Implicit GEMM Traits using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp index 80283a0467..7b05a328e0 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp @@ -58,6 +58,8 @@ struct InstanceTraits(); // 21. ADataType - oss << "," << detail::type_name(); // 22. BDataType - oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer - oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched - oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer - oss << "," << kNumWaveGroups; // 26. NumWaveGroups - oss << "," << detail::type_name(); // 27. AccDataType - oss << "," << detail::type_name(); // 28. EDataType - oss << "," << detail::tuple_name(); // 29. DsDataType + oss << "," << kExplicitGemm; // 12. ExplicitGemm + oss << "," << kMPerBlock; // 13. MPerBlock + oss << "," << kNPerBlock; // 14. NPerBlock + oss << "," << kKPerBlock; // 15. KPerBlock + oss << "," << kMWarp; // 16. MWarp + oss << "," << kNWarp; // 17. NWarp + oss << "," << kKWarp; // 18. KWarp + oss << "," << kMWarpTile; // 19. MWarpTile + oss << "," << kNWarpTile; // 20. NWarpTile + oss << "," << kKWarpTile; // 21. KWarpTile + oss << "," << detail::type_name(); // 22. ADataType + oss << "," << detail::type_name(); // 23. BDataType + oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched + oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer + oss << "," << kNumWaveGroups; // 27. NumWaveGroups + oss << "," << detail::type_name(); // 28. AccDataType + oss << "," << detail::type_name(); // 29. EDataType + oss << "," << detail::tuple_name(); // 30. DsDataType oss << "," - << detail::elementwise_op_name(); // 30. + << detail::elementwise_op_name(); // 31. // CDEElementwiseOperation oss << ">"; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp index f856a48e59..f911b8dc83 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp @@ -58,6 +58,8 @@ struct InstanceTraits(); // 21. ADataType - oss << "," << detail::type_name(); // 22. BDataType - oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer - oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched - oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer - oss << "," << kNumWaveGroups; // 26. NumWaveGroups - oss << "," << detail::type_name(); // 27. AccDataType - oss << "," << detail::type_name(); // 28. EDataType - oss << "," << detail::tuple_name(); // 29. DsDataType + oss << "," << kExplicitGemm; // 12. ExplicitGemm + oss << "," << kMPerBlock; // 13. MPerBlock + oss << "," << kNPerBlock; // 14. NPerBlock + oss << "," << kKPerBlock; // 15. KPerBlock + oss << "," << kMWarp; // 16. MWarp + oss << "," << kNWarp; // 17. NWarp + oss << "," << kKWarp; // 18. KWarp + oss << "," << kMWarpTile; // 19. MWarpTile + oss << "," << kNWarpTile; // 20. NWarpTile + oss << "," << kKWarpTile; // 21. KWarpTile + oss << "," << detail::type_name(); // 22. ADataType + oss << "," << detail::type_name(); // 23. BDataType + oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched + oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer + oss << "," << kNumWaveGroups; // 27. NumWaveGroups + oss << "," << detail::type_name(); // 28. AccDataType + oss << "," << detail::type_name(); // 29. EDataType + oss << "," << detail::tuple_name(); // 30. DsDataType oss << "," - << detail::elementwise_op_name(); // 30. + << detail::elementwise_op_name(); // 31. // CDEElementwiseOperation oss << ">"; diff --git a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp index c42a4f44dd..9db225db30 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_forward.hpp @@ -58,6 +58,8 @@ struct InstanceTraits(); // 21. ADataType - oss << "," << detail::type_name(); // 22. BDataType - oss << "," << GemmPipeline::GetPipelineName(); // 23. BlkGemmPipelineVer - oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 24. BlkGemmPipeSched - oss << "," << kDoubleSmemBuffer; // 25. DoubleSmemBuffer - oss << "," << kNumWaveGroups; // 26. NumWaveGroups - oss << "," << detail::type_name(); // 27. AccDataType - oss << "," << detail::type_name(); // 28. EDataType - oss << "," << detail::tuple_name(); // 29. DsDataType + oss << "," << kExplicitGemm; // 12. ExplicitGemm + oss << "," << kMPerBlock; // 13. MPerBlock + oss << "," << kNPerBlock; // 14. NPerBlock + oss << "," << kKPerBlock; // 15. KPerBlock + oss << "," << kMWarp; // 16. MWarp + oss << "," << kNWarp; // 17. NWarp + oss << "," << kKWarp; // 18. KWarp + oss << "," << kMWarpTile; // 19. MWarpTile + oss << "," << kNWarpTile; // 20. NWarpTile + oss << "," << kKWarpTile; // 21. KWarpTile + oss << "," << detail::type_name(); // 22. ADataType + oss << "," << detail::type_name(); // 23. BDataType + oss << "," << GemmPipeline::GetPipelineName(); // 24. BlkGemmPipelineVer + oss << "," << detail::pipeline_scheduler_name(kPipelineScheduler); // 25. BlkGemmPipeSched + oss << "," << kDoubleSmemBuffer; // 26. DoubleSmemBuffer + oss << "," << kNumWaveGroups; // 27. NumWaveGroups + oss << "," << detail::type_name(); // 28. AccDataType + oss << "," << detail::type_name(); // 29. EDataType + oss << "," << detail::tuple_name(); // 30. DsDataType oss << "," - << detail::elementwise_op_name(); // 30. + << detail::elementwise_op_name(); // 31. // CDEElementwiseOperation oss << ">"; diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index d6d4749db7..6b18095544 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -21,7 +21,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 4 /*VectorSizeB*/, 4 /*VectorSizeC*/, 1 /*NumGroupsToMerge*/, - false /*EnableSplitImage*/>; + false /*EnableSplitImage*/, + false /*ExplicitGemm*/>; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>, @@ -106,6 +107,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) ",4" // VectorSizeC ",1" // NumGroupsToMerge ",0" // EnableSplitImage + ",0" // ExplicitGemm ",128" // MPerBlock ",128" // NPerBlock ",32" // KPerBlock diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index a6aee7b210..3ecd06e33d 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -123,7 +123,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 4 /*VectorSizeB*/, 4 /*VectorSizeC*/, 1 /*NumGroupsToMerge*/, - false /*EnableSplitImage*/>; + false /*EnableSplitImage*/, + false /*ExplicitGemm*/>; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>, @@ -208,6 +209,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) ",4" // VectorSizeC ",1" // NumGroupsToMerge ",0" // EnableSplitImage + ",0" // ExplicitGemm ",128" // MPerBlock ",128" // NPerBlock ",32" // KPerBlock diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 1203686f6c..9da707bfec 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -734,7 +734,8 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 4 /*VectorSizeB*/, 4 /*VectorSizeC*/, 1 /*NumGroupsToMerge*/, - false /*EnableSplitImage*/>; + false /*EnableSplitImage*/, + false /*ExplicitGemm*/>; using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence<128 /*M_Tile*/, 128 /*N_Tile*/, 32 /*K_Tile*/>, @@ -818,6 +819,7 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) ",4" // VectorSizeC ",1" // NumGroupsToMerge ",0" // EnableSplitImage + ",0" // ExplicitGemm ",128" // MPerBlock ",128" // NPerBlock ",32" // KPerBlock diff --git a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp index 806a471397..eb7e3bcf94 100644 --- a/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp @@ -228,10 +228,34 @@ struct BatchedGemmKernel CDataType* c_ptr = static_cast(kargs.e_ptr) + batch_offset_C; // allocate LDS - __shared__ char smem_ptr[GetSmemSize()]; + __shared__ char smem_ptr0[GetSmemSize()]; - UniversalGemmKernel::RunGemm( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n); + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr1[GetSmemSize()]; + UniversalGemmKernel::RunGemm2LDS({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr0, + smem_ptr1, + kargs, + splitk_batch_offset, + i_m, + i_n); + } + else + { + UniversalGemmKernel::RunGemm({a_ptr}, + {b_ptr}, + {/*ds_ptr*/}, + c_ptr, + smem_ptr0, + kargs, + splitk_batch_offset, + i_m, + i_n); + } } }; diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index 86f5684e73..309860810c 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -511,7 +511,7 @@ template struct GroupedConvolutionBackwardDataKernel { - static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_; + static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial; static constexpr ConvolutionSpecialization ConvSpecialization = GroupedConvTraitsType_::ConvSpecialization; using TilePartitioner = remove_cvref_t; @@ -556,6 +556,7 @@ struct GroupedConvolutionBackwardDataKernel static_assert(std::is_same_v, "Not supported!"); static_assert(std::is_same_v, "Not supported C GEMM layout!"); + static_assert(GroupedConvTraitsType_::ExplicitGemm == false, "Not supported yet!"); [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -983,7 +984,7 @@ struct GroupedConvolutionBackwardDataKernel return group_id; } - CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const + CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized& kargs) const { const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); const index_t group_id = FindGroupId(kargs, blockIdX); diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 0143afae7a..7942d5e6e3 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -370,7 +370,7 @@ template struct GroupedConvolutionBackwardWeightKernel { - static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_; + static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial; static constexpr ConvolutionSpecialization ConvSpecialization = GroupedConvTraitsType_::ConvSpecialization; using TilePartitioner = remove_cvref_t; @@ -411,6 +411,9 @@ struct GroupedConvolutionBackwardWeightKernel static_assert(std::is_same_v, "Not supported!"); static_assert(std::is_same_v, "Not supported!"); static_assert(std::is_same_v, "Not supported!"); + static_assert(GroupedConvTraitsType_::ExplicitGemm == false || + GroupedConvTraitsType_::NumGroupsToMerge == 1, + "Not supported!"); [[nodiscard]] CK_TILE_HOST static const std::string GetName() { @@ -503,22 +506,6 @@ struct GroupedConvolutionBackwardWeightKernel index_t splitted_k; }; - CK_TILE_HOST static auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized& kargs, - const stream_config& s) - { - return [&]() { - if(kargs.k_batch > 1) - { - // Total number of convolution groups (ConvG) = GemmBatch * NumGroupsPerBatch - // since we require that ConvG % NumGroupsPerBatch == 0. - const auto wei_size = - kargs.GemmBatch * kargs.GemmM * kargs.GemmN * kargs.NumGroupsPerBatch; - hipGetErrorString( - hipMemsetAsync(kargs.wei_ptr, 0, wei_size * sizeof(WeiDataType), s.stream_id_)); - } - }; - } - CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized& kargs) { @@ -588,6 +575,14 @@ struct GroupedConvolutionBackwardWeightKernel } } + if constexpr(GroupedConvTraitsType_::ExplicitGemm && + ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + CK_TILE_ERROR( + "Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!"); + return false; + } + namespace ctc = tensor_layout::convolution; if constexpr(std::is_same_v || std::is_same_v || @@ -886,61 +881,104 @@ struct GroupedConvolutionBackwardWeightKernel c_block_window, c_block_tile, d_block_window, smem_ptr_0); } - CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const + CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const { - const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); - const auto [iM, iN] = - TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); - const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); - const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); + static_assert(NumDTensor == 0, "Not supported!"); + using ExplicitBatchedGemmKernel = + BatchedGemmKernel; + const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{ + {{kargs.out_ptr}, + {kargs.in_ptr}, + {}, + kargs.wei_ptr, + kargs.GemmM, + kargs.GemmN, + kargs.GemmK, + {kargs.GemmM * kargs.GemmBatch}, + {kargs.GemmN * kargs.GemmBatch}, + {}, + kargs.GemmN, + kargs.k_batch}, + kargs.GemmM, + kargs.GemmN, + kargs.GemmM * kargs.GemmN, + kargs.GemmBatch}; + ExplicitBatchedGemmKernel{}(batched_gemm_kargs); + } - const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z); - const index_t num_loop = amd_wave_read_first_lane( - ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock)); - const index_t i_k = - amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock); - - const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); - const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY); - const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY); - const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY); - - // options - // conv_bwd_weight = Out * In = Weight - const OutDataType* a_ptr = static_cast(kargs.out_ptr) + group_offset_a; - const InDataType* b_ptr = static_cast(kargs.in_ptr) + group_offset_b; - WeiDataType* c_ptr = static_cast(kargs.wei_ptr) + group_offset_c; - - __shared__ char smem_ptr_0[GetSmemSize()]; - - if constexpr(GemmPipeline::DoubleSmemBuffer == true) + CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const + { + if constexpr(GroupedConvTraitsType_::ExplicitGemm) { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) + CallExplicitGemm(kargs); + } + else + { + const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); + const auto [iM, iN] = + TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX); + const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); + + const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z); + const index_t num_loop = amd_wave_read_first_lane(ck_tile::integer_divide_ceil( + kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock)); + const index_t i_k = + amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock); + + const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); + const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY); + + // options + // conv_bwd_weight = Out * In = Weight + const OutDataType* a_ptr = + static_cast(kargs.out_ptr) + group_offset_a; + const InDataType* b_ptr = static_cast(kargs.in_ptr) + group_offset_b; + WeiDataType* c_ptr = static_cast(kargs.wei_ptr) + group_offset_c; + + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) { - RunGemm2LDS(a_ptr, + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(a_ptr, + b_ptr, + kargs.ds_ptr, + c_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + num_loop, + i_m, + i_n, + i_k); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value)) + { + RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, - smem_ptr_1, kargs, num_loop, i_m, i_n, i_k); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - RunGemm( - a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k); + } } } } diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index a07ba1b05d..6d97f7b758 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -490,6 +490,9 @@ struct GroupedConvolutionForwardKernel static_assert(std::is_same_v, "Not supported!"); static_assert(std::is_same_v, "Not supported!"); static_assert(std::is_same_v, "Not supported!"); + static_assert(GroupedConvTraitsType_::ExplicitGemm == false || + GroupedConvTraitsType_::NumGroupsToMerge == 1, + "Not supported!"); // Helper struct for spatial coordinates struct SpatialCoords @@ -678,6 +681,14 @@ struct GroupedConvolutionForwardKernel } } + if constexpr(GroupedConvTraitsType_::ExplicitGemm && + ConvSpecialization != ConvolutionSpecialization::Filter1x1Stride1Pad0) + { + CK_TILE_ERROR( + "Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!"); + return false; + } + namespace ctc = tensor_layout::convolution; if constexpr(std::is_same_v || std::is_same_v || @@ -974,135 +985,189 @@ struct GroupedConvolutionForwardKernel c_block_window, c_block_tile, d_block_window, smem_ptr_0); } - CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const + CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const { - const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); - const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); + static_assert(NumDTensor == 0, "Not supported!"); + using ExplicitBatchedGemmKernel = + BatchedGemmKernel; + const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{ + {{kargs.in_ptr}, + {kargs.wei_ptr}, + {}, + kargs.out_ptr, + kargs.GemmM, + kargs.GemmN, + kargs.GemmK, + {kargs.GemmK * kargs.GemmBatch}, + {kargs.GemmK}, + {}, + kargs.GemmBatch * kargs.GemmN, + kargs.k_batch}, + kargs.GemmK, + kargs.GemmN * kargs.GemmK, + kargs.GemmN, + kargs.GemmBatch}; + ExplicitBatchedGemmKernel{}(batched_gemm_kargs); + } - const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY); - const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY); - const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY); - - // Split-N handling: Get which split this workgroup handles - const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z); - - // Calculate batch offset for this split - const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split); - - // Calculate memory offsets for this split - const long_index_t input_batch_offset = static_cast(batch_offset) * - static_cast(kargs.input_batch_stride); - const long_index_t output_batch_offset = - static_cast(batch_offset) * - static_cast(kargs.output_batch_stride); - - // Calculate base pointers with group and batch offsets - const InDataType* base_a_ptr = - static_cast(kargs.in_ptr) + group_offset_a + input_batch_offset; - const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + - group_offset_b; // No batch offset for weights! - OutDataType* base_c_ptr = - static_cast(kargs.out_ptr) + group_offset_c + output_batch_offset; - - // Apply group offsets to D tensors - std::array ds_ptr_with_offsets; - static_for<0, NumDTensor, 1>{}([&](auto d) { - using DType = std::tuple_element_t; - ds_ptr_with_offsets[d] = - static_cast(kargs.ds_ptr[d]) + group_offset_c + output_batch_offset; - }); - - // ===================================================================== - // Split-image: Map local block to global tile index (if enabled) - // ===================================================================== - const InDataType* a_ptr; - OutDataType* c_ptr; - index_t i_m = 0; - index_t i_n = 0; - - // Pre-calculate block_id (used in both split-image and non-split paths) - const index_t block_id = static_cast(blockIdX); - - if constexpr(EnableSplitImage) + CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized& kargs) const + { + if constexpr(GroupedConvTraitsType_::ExplicitGemm) { - // Add spatial offsets for split-image (constexpr optimization) - a_ptr = base_a_ptr + kargs.spatial_offset_in; - c_ptr = base_c_ptr + kargs.spatial_offset_out; - - // Find which piece owns this block using binary search - // Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp - const index_t piece_id = - FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces); - const auto& piece = kargs.split_image.pieces[piece_id]; - const auto& split_info = kargs.split_image; - - // Calculate local block ID and tile indices - const index_t local_block_id = block_id - piece.block_start; - const index_t local_gemm_m = - kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size; - const auto [local_tile_m, local_tile_n] = - TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id); - - // Extract batch and spatial coordinates from local tile - const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock; - const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size; - const index_t local_n = local_m_start / spatial_per_batch; - const index_t local_spatial_flat = local_m_start % spatial_per_batch; - - // Convert to local spatial coordinates - const auto local_coords = - UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size); - - // Convert to global spatial coordinates - const index_t global_n = local_n; - const index_t global_d = piece.d_start + local_coords.d; - const index_t global_h = piece.h_start + local_coords.h; - const index_t global_w = piece.w_start + local_coords.w; - - // Convert to global M index - const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated - const index_t global_spatial_flat = FlattenSpatial( - global_d, global_h, global_w, split_info.total_h, split_info.total_w); - const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat; - - // Set tile indices for GEMM operation - i_m = amd_wave_read_first_lane(global_m); - i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock); + CallExplicitGemm(kargs); } else { - // No spatial offsets needed for regular path - a_ptr = base_a_ptr; - c_ptr = base_c_ptr; + const auto blockIdX = amd_wave_read_first_lane(blockIdx.x); + const auto blockIdY = amd_wave_read_first_lane(blockIdx.y); - // No split-image: use standard tile partitioning - const auto [iM, iN] = - TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id); - i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); - i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); - } + const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY); + const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY); + const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY); - // Use global descriptors for all cases - const auto& a_desc = kargs.a_grid_desc_m_k; - const auto& b_desc = kargs.b_grid_desc_n_k; - const auto& c_desc = kargs.c_grid_desc_m_n; + // Split-N handling: Get which split this workgroup handles + const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z); - // allocate LDS - __shared__ char smem_ptr_0[GetSmemSize()]; + // Calculate batch offset for this split + const index_t batch_offset = amd_wave_read_first_lane(blockIdZ * kargs.n_per_split); - if constexpr(GemmPipeline::DoubleSmemBuffer == true) - { - __shared__ char smem_ptr_1[GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) + // Calculate memory offsets for this split + const long_index_t input_batch_offset = + static_cast(batch_offset) * + static_cast(kargs.input_batch_stride); + const long_index_t output_batch_offset = + static_cast(batch_offset) * + static_cast(kargs.output_batch_stride); + + // Calculate base pointers with group and batch offsets + const InDataType* base_a_ptr = + static_cast(kargs.in_ptr) + group_offset_a + input_batch_offset; + const WeiDataType* b_ptr = static_cast(kargs.wei_ptr) + + group_offset_b; // No batch offset for weights! + OutDataType* base_c_ptr = + static_cast(kargs.out_ptr) + group_offset_c + output_batch_offset; + + // Apply group offsets to D tensors + std::array ds_ptr_with_offsets; + static_for<0, NumDTensor, 1>{}([&](auto d) { + using DType = std::tuple_element_t; + ds_ptr_with_offsets[d] = static_cast(kargs.ds_ptr[d]) + + group_offset_c + output_batch_offset; + }); + + // ===================================================================== + // Split-image: Map local block to global tile index (if enabled) + // ===================================================================== + const InDataType* a_ptr; + OutDataType* c_ptr; + index_t i_m = 0; + index_t i_n = 0; + + // Pre-calculate block_id (used in both split-image and non-split paths) + const index_t block_id = static_cast(blockIdX); + + if constexpr(EnableSplitImage) { - RunGemm2LDS(a_ptr, + // Add spatial offsets for split-image (constexpr optimization) + a_ptr = base_a_ptr + kargs.spatial_offset_in; + c_ptr = base_c_ptr + kargs.spatial_offset_out; + + // Find which piece owns this block using binary search + // Reference: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp + const index_t piece_id = + FindPieceId(block_id, kargs.split_image, kargs.num_spatial_pieces); + const auto& piece = kargs.split_image.pieces[piece_id]; + const auto& split_info = kargs.split_image; + + // Calculate local block ID and tile indices + const index_t local_block_id = block_id - piece.block_start; + const index_t local_gemm_m = + kargs.n_per_split * piece.d_size * piece.h_size * piece.w_size; + const auto [local_tile_m, local_tile_n] = + TilePartitioner{local_gemm_m, kargs.GemmN}.GetOutputTileIndex(local_block_id); + + // Extract batch and spatial coordinates from local tile + const index_t local_m_start = local_tile_m * TilePartitioner::MPerBlock; + const index_t spatial_per_batch = piece.d_size * piece.h_size * piece.w_size; + const index_t local_n = local_m_start / spatial_per_batch; + const index_t local_spatial_flat = local_m_start % spatial_per_batch; + + // Convert to local spatial coordinates + const auto local_coords = + UnflattenSpatial(local_spatial_flat, piece.h_size, piece.w_size); + + // Convert to global spatial coordinates + const index_t global_n = local_n; + const index_t global_d = piece.d_start + local_coords.d; + const index_t global_h = piece.h_start + local_coords.h; + const index_t global_w = piece.w_start + local_coords.w; + + // Convert to global M index + const index_t global_spatial_per_batch = split_info.total_spatial; // Pre-calculated + const index_t global_spatial_flat = FlattenSpatial( + global_d, global_h, global_w, split_info.total_h, split_info.total_w); + const index_t global_m = global_n * global_spatial_per_batch + global_spatial_flat; + + // Set tile indices for GEMM operation + i_m = amd_wave_read_first_lane(global_m); + i_n = amd_wave_read_first_lane(local_tile_n * TilePartitioner::NPerBlock); + } + else + { + // No spatial offsets needed for regular path + a_ptr = base_a_ptr; + c_ptr = base_c_ptr; + + // No split-image: use standard tile partitioning + const auto [iM, iN] = + TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(block_id); + i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); + i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); + } + + // Use global descriptors for all cases + const auto& a_desc = kargs.a_grid_desc_m_k; + const auto& b_desc = kargs.b_grid_desc_n_k; + const auto& c_desc = kargs.c_grid_desc_m_n; + + // allocate LDS + __shared__ char smem_ptr_0[GetSmemSize()]; + + if constexpr(GemmPipeline::DoubleSmemBuffer == true) + { + __shared__ char smem_ptr_1[GetSmemSize()]; + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value)) + { + RunGemm2LDS(a_ptr, + b_ptr, + ds_ptr_with_offsets, + c_ptr, + smem_ptr_0, + smem_ptr_1, + a_desc, + b_desc, + c_desc, + kargs.GemmK, + i_m, + i_n, + kargs.elfunc); + } + } + else + { + if constexpr(!(EpiloguePipeline::MemoryOperation == + memory_operation_enum::atomic_add && + GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value)) + { + RunGemm(a_ptr, b_ptr, ds_ptr_with_offsets, c_ptr, smem_ptr_0, - smem_ptr_1, a_desc, b_desc, c_desc, @@ -1110,26 +1175,7 @@ struct GroupedConvolutionForwardKernel i_m, i_n, kargs.elfunc); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - RunGemm(a_ptr, - b_ptr, - ds_ptr_with_offsets, - c_ptr, - smem_ptr_0, - a_desc, - b_desc, - c_desc, - kargs.GemmK, - i_m, - i_n, - kargs.elfunc); + } } } } diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 8ea6cffa7d..27349a0978 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -63,7 +63,8 @@ template + bool EnableSplitImage_ = false, + bool ExplicitGemm_ = false> struct GroupedConvTraits { private: @@ -89,8 +90,9 @@ struct GroupedConvTraits using ELayout = ck_tile::tensor_layout::gemm::RowMajor; }; // Compile time parameters - static constexpr bool EnableSplitImage = EnableSplitImage_; static constexpr index_t NumGroupsToMerge = NumGroupsToMerge_; + static constexpr bool EnableSplitImage = EnableSplitImage_; + static constexpr bool ExplicitGemm = ExplicitGemm_; static constexpr index_t NDimSpatial = NDimSpatial_; static constexpr ConvolutionSpecialization ConvSpecialization = ConvSpecialization_; using InLayout = InLayout_;