diff --git a/experimental/grouped_convolution_tile_instances/configs/backward_data/profiler/nhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/backward_data/profiler/nhwgc_bf16.conf index 16a93f0066..6878b029b8 100644 --- a/experimental/grouped_convolution_tile_instances/configs/backward_data/profiler/nhwgc_bf16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/backward_data/profiler/nhwgc_bf16.conf @@ -80,3 +80,10 @@ DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,bf DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,bf16,bf16,fp32,EmptyTuple,bf16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,64,16,16,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,16,4,true,Seq(4,4,16),Seq(0,2,1),Seq(0,2,1),1,1,1,true,1,1,Seq(1,16,1,16),1,1,Default,bf16,bf16,1,1> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,bf16,bf16,fp32,EmptyTuple,bf16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,32,8,8,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,4,true,Seq(4,8,8),Seq(0,2,1),Seq(0,2,1),1,1,1,true,1,1,Seq(1,16,1,16),1,1,Default,bf16,bf16,1,1> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,bf16,bf16,fp32,EmptyTuple,bf16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,16,4,4,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,4),Seq(0,2,1),Seq(0,2,1),1,1,1,true,1,1,Seq(1,16,1,16),1,1,Default,bf16,bf16,1,1> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,1,0,0,256,128,64,2,2,1,16,16,32,bf16,bf16,WAVELET,Intrawave,0,1,fp32,bf16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,4,1,0,0,256,128,64,2,2,1,16,16,32,bf16,bf16,WAVELET,Intrawave,0,1,fp32,bf16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,4,8,8,1,0,0,256,128,32,2,2,1,16,16,32,bf16,bf16,WAVELET,Intrawave,0,1,fp32,bf16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,1,0,0,128,64,64,2,2,1,16,16,32,bf16,bf16,WAVELET,Intrawave,0,1,fp32,bf16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,1,0,0,64,64,64,2,2,1,16,16,32,bf16,bf16,WAVELET,Intrawave,0,1,fp32,bf16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,4,1,0,0,64,64,64,2,2,1,16,16,32,bf16,bf16,WAVELET,Intrawave,0,1,fp32,bf16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,32,0,0,256,32,64,4,1,1,16,16,32,bf16,bf16,WAVELET,Intrawave,0,1,fp32,bf16,EmptyTuple,PassThrough> diff --git a/experimental/grouped_convolution_tile_instances/configs/backward_data/profiler/nhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/backward_data/profiler/nhwgc_fp16.conf index 39893398a0..63ee35926f 100644 --- a/experimental/grouped_convolution_tile_instances/configs/backward_data/profiler/nhwgc_fp16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/backward_data/profiler/nhwgc_fp16.conf @@ -80,3 +80,10 @@ DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,fp DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,fp16,fp16,fp32,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,64,16,16,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,16,4,true,Seq(4,4,16),Seq(0,2,1),Seq(0,2,1),1,1,1,true,1,1,Seq(1,16,1,16),1,1,Default,fp16,fp16,1,1> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,fp16,fp16,fp32,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,32,8,8,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,4,true,Seq(4,8,8),Seq(0,2,1),Seq(0,2,1),1,1,1,true,1,1,Seq(1,16,1,16),1,1,Default,fp16,fp16,1,1> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,fp16,fp16,fp32,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,16,4,4,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,16,4),Seq(0,2,1),Seq(0,2,1),1,1,1,true,1,1,Seq(1,16,1,16),1,1,Default,fp16,fp16,1,1> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,1,0,0,256,128,64,2,2,1,16,16,32,fp16,fp16,WAVELET,Intrawave,0,1,fp32,fp16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,4,1,0,0,256,128,64,2,2,1,16,16,32,fp16,fp16,WAVELET,Intrawave,0,1,fp32,fp16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,4,8,8,1,0,0,256,128,32,2,2,1,16,16,32,fp16,fp16,WAVELET,Intrawave,0,1,fp32,fp16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,1,0,0,128,64,64,2,2,1,16,16,32,fp16,fp16,WAVELET,Intrawave,0,1,fp32,fp16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,1,0,0,64,64,64,2,2,1,16,16,32,fp16,fp16,WAVELET,Intrawave,0,1,fp32,fp16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,4,1,0,0,64,64,64,2,2,1,16,16,32,fp16,fp16,WAVELET,Intrawave,0,1,fp32,fp16,EmptyTuple,PassThrough> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,32,0,0,256,32,64,4,1,1,16,16,32,fp16,fp16,WAVELET,Intrawave,0,1,fp32,fp16,EmptyTuple,PassThrough> diff --git a/experimental/grouped_convolution_tile_instances/configs/backward_data/tests/nhwgc_bf16.conf b/experimental/grouped_convolution_tile_instances/configs/backward_data/tests/nhwgc_bf16.conf index f46c741ee6..2aab6f5ac3 100644 --- a/experimental/grouped_convolution_tile_instances/configs/backward_data/tests/nhwgc_bf16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/backward_data/tests/nhwgc_bf16.conf @@ -14,3 +14,4 @@ DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,bf DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,bf16,bf16,fp32,EmptyTuple,bf16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,128,32,16,4,4,32,32,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,8,4),Seq(0,2,1),Seq(0,2,1),1,2,1,true,1,1,Seq(1,16,1,16),2,1,Default,bf16,bf16,1,1> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,bf16,bf16,fp32,EmptyTuple,bf16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,32,8,8,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,4,true,Seq(4,2,8),Seq(0,2,1),Seq(0,2,1),1,8,1,true,1,1,Seq(1,64,1,4),4,1,Default,bf16,bf16,1,1> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,bf16,bf16,fp32,EmptyTuple,bf16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,64,16,16,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,16,4,true,Seq(4,4,16),Seq(0,2,1),Seq(0,2,1),1,1,1,true,1,1,Seq(1,16,1,16),1,1,Default,bf16,bf16,1,1> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,1,0,0,64,64,64,2,2,1,16,16,32,bf16,bf16,WAVELET,Intrawave,0,1,fp32,bf16,EmptyTuple,PassThrough> diff --git a/experimental/grouped_convolution_tile_instances/configs/backward_data/tests/nhwgc_fp16.conf b/experimental/grouped_convolution_tile_instances/configs/backward_data/tests/nhwgc_fp16.conf index adeb3b5ef3..fba1b97b1b 100644 --- a/experimental/grouped_convolution_tile_instances/configs/backward_data/tests/nhwgc_fp16.conf +++ b/experimental/grouped_convolution_tile_instances/configs/backward_data/tests/nhwgc_fp16.conf @@ -14,3 +14,4 @@ DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,fp DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,fp16,fp16,fp32,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,128,32,16,4,4,32,32,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,4,4,true,Seq(4,8,4),Seq(0,2,1),Seq(0,2,1),1,2,1,true,1,1,Seq(1,16,1,16),2,1,Default,fp16,fp16,1,1> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,fp16,fp16,fp32,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,32,8,8,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,8,4,true,Seq(4,2,8),Seq(0,2,1),Seq(0,2,1),1,8,1,true,1,1,Seq(1,64,1,4),4,1,Default,fp16,fp16,1,1> DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle<2,NHWGK,GKYXC,EmptyTuple,NHWGC,fp16,fp16,fp32,EmptyTuple,fp16,PassThrough,PassThrough,PassThrough,Filter1x1Stride1Pad0,1,1,1,256,64,16,64,16,16,16,16,1,1,Seq(4,64,1),Seq(1,0,2),Seq(1,0,2),2,16,4,true,Seq(4,4,16),Seq(0,2,1),Seq(0,2,1),1,1,1,true,1,1,Seq(1,16,1,16),1,1,Default,fp16,fp16,1,1> +GroupedConvolutionBackwardDataKernel<2,Default,NHWGK,GKYXC,EmptyTuple,NHWGC,8,8,8,1,0,0,64,64,64,2,2,1,16,16,32,fp16,fp16,WAVELET,Intrawave,0,1,fp32,fp16,EmptyTuple,PassThrough> diff --git a/experimental/grouped_convolution_tile_instances/generate_instances.py b/experimental/grouped_convolution_tile_instances/generate_instances.py index ba72c58403..d7183441df 100755 --- a/experimental/grouped_convolution_tile_instances/generate_instances.py +++ b/experimental/grouped_convolution_tile_instances/generate_instances.py @@ -281,17 +281,21 @@ STREAMK_REDUCTION_STRATEGY = { } -def parse_native_bwd_weight_instance(args, instance_id, problem_name): - """Parse a native CK Tile instance string (GroupedConvolutionBackwardWeightKernel<...>). +def parse_native_instance(args, instance_id, problem_name, has_streamk, has_two_stage): + """Parse a native CK Tile grouped-conv instance string for any direction + (GroupedConvolution{Forward,BackwardData,BackwardWeight}Kernel<...>). - Fields (0-indexed after splitting on commas inside <>): + Fields (0-indexed after splitting on commas inside <>), shared by all directions: 0: NDimSpatial, 1: ConvSpec, 2: InLayout, 3: WeiLayout, 4: DsLayout, 5: OutLayout, 6: VecA, 7: VecB, 8: VecC, 9: NumGroupsToMerge, 10: SplitImage, 11: ExplicitGemm, 12: MPerBlock, 13: NPerBlock, 14: KPerBlock, 15: MWarp, 16: NWarp, 17: KWarp, 18: MWarpTile, 19: NWarpTile, 20: KWarpTile, 21: ADataType, 22: BDataType, 23: PipelineName, 24: Scheduler, 25: DoubleSmemBuffer, 26: NumWaveGroups, 27: AccDataType, 28: EDataType, 29: DsDataType, 30: CDEElementwiseOp, - 31: IsStreamK, [32: ReductionStrategy, 33: PersistentDP] + [31: IsStreamK, 32: ReductionStrategy, 33: PersistentDP] (backward_weight only) + + has_streamk: direction carries the trailing StreamK fields (backward_weight only). + has_two_stage: direction has a two-stage path (backward_weight only); else False. """ spec = args[1] tile_size = [int(args[12]), int(args[13]), int(args[14])] @@ -314,10 +318,14 @@ def parse_native_bwd_weight_instance(args, instance_id, problem_name): split_image = int(args[10]) != 0 explicit_gemm = int(args[11]) != 0 - is_streamk = int(args[31]) != 0 + is_two_stage = ( + has_two_stage + and get_dtype(problem_name) != "float" + and scalar_per_vector[2] == 1 + ) + is_streamk = has_streamk and int(args[31]) != 0 streamk_reduction_strategy = None streamk_persistent = False - is_two_stage = get_dtype(problem_name) != "float" and scalar_per_vector[2] == 1 if is_streamk: is_two_stage = False reduction_int = int(args[32]) @@ -347,59 +355,21 @@ def parse_native_bwd_weight_instance(args, instance_id, problem_name): ) -def parse_native_fwd_instance(args, instance_id, _): - """Parse a native CK Tile forward conv instance string - (GroupedConvolutionForwardKernel<...>). +def parse_native_bwd_weight_instance(args, instance_id, problem_name): + return parse_native_instance( + args, instance_id, problem_name, has_streamk=True, has_two_stage=True + ) - Same field layout as backward_weight (fields 0-30) but with no trailing - StreamK fields. Forward has no two-stage path, so two_stage is always False. - """ - spec = args[1] - tile_size = [int(args[12]), int(args[13]), int(args[14])] - warps = [int(args[15]), int(args[16]), int(args[17])] - warp_tile = [int(args[18]), int(args[19]), int(args[20])] - pipeline_name = args[23] - if pipeline_name not in PIPELINE_NAME_TO_VERSION: - raise RuntimeError( - f"Unknown pipeline name '{pipeline_name}' in native instance {instance_id}" - ) - pipeline_version = PIPELINE_NAME_TO_VERSION[pipeline_name] - - scheduler = args[24] - double_smem_buffer = int(args[25]) != 0 - num_wave_groups = int(args[26]) - - scalar_per_vector = [int(args[6]), int(args[7]), int(args[8])] - num_groups_to_merge = int(args[9]) - split_image = int(args[10]) != 0 - explicit_gemm = int(args[11]) != 0 - - return ConvInstanceTemplateParams( - spec, - tile_size, - warps, - warp_tile, - double_smem_buffer, - num_wave_groups, - False, # forward has no two-stage path - pipeline_version, - scheduler, - scalar_per_vector, - num_groups_to_merge, - split_image, - explicit_gemm, - instance_id, - streamk_enabled=False, - streamk_reduction_strategy=None, - streamk_persistent=False, +def parse_native_fwd_instance(args, instance_id, problem_name): + return parse_native_instance( + args, instance_id, problem_name, has_streamk=False, has_two_stage=False ) def parse_native_bwd_data_instance(args, instance_id, problem_name): - """Parse a native CK Tile backward data instance string.""" - raise NotImplementedError( - "Native backward data instance parsing is not yet implemented." + return parse_native_instance( + args, instance_id, problem_name, has_streamk=False, has_two_stage=False ) diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index a4419ff0c1..279104ecdb 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -529,7 +529,9 @@ struct GroupedConvolutionBackwardDataKernel using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; - static constexpr index_t kBlockSize = GemmPipeline::BlockSize; + // Wavelet pipelines launch extra load waves (LaunchBlockSize > BlockSize); others use + // BlockSize. See GroupedConvLaunchBlockSize in grouped_convolution_utils.hpp. + static constexpr index_t kBlockSize = GroupedConvLaunchBlockSize; using OutDataType = remove_cvref_t; using WeiDataType = remove_cvref_t; @@ -934,29 +936,31 @@ struct GroupedConvolutionBackwardDataKernel const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - // Run Epilogue Pipeline with k_batch dispatch - if(k_batch == 1) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, group_id, block_idx_m, block_idx_n); - - EpiloguePipeline{} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - else - { - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) + // Run the epilogue with split-K dispatch, wrapped for wavelet load/math waves. + RunWaveletAwareEpilogue([&]() { + if(k_batch == 1) { - auto c_block_window = MakeCBlockWindow( + auto c_block_window = MakeCBlockWindow( c_ptr, kargs, group_id, block_idx_m, block_idx_n); EpiloguePipeline{} .template operator()( c_block_window, c_block_tile, d_block_window, smem_ptr_0); } - } + else + { + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value)) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + } + }); } CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 3acd703d13..30e46ad0a0 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -456,21 +456,9 @@ struct GroupedConvolutionBackwardWeightKernel using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; - // For wavelet, LaunchBlockSize > BlockSize. Use LaunchBlockSize for kernel launch. - template - struct has_launch_block_size : std::false_type - { - }; - template - struct has_launch_block_size> : std::true_type - { - }; - static constexpr index_t kBlockSize = []() { - if constexpr(has_launch_block_size::value) - return GemmPipeline::LaunchBlockSize; - else - return GemmPipeline::BlockSize; - }(); + // Wavelet pipelines launch extra load waves (LaunchBlockSize > BlockSize); others use + // BlockSize. See GroupedConvLaunchBlockSize in grouped_convolution_utils.hpp. + static constexpr index_t kBlockSize = GroupedConvLaunchBlockSize; using OutDataType = remove_cvref_t; using InDataType = remove_cvref_t; @@ -1061,22 +1049,6 @@ struct GroupedConvolutionBackwardWeightKernel {block_idx_k, block_idx_m}); } - // SFINAE helper: detect GemmPipeline::IsWavelet - template - struct has_is_wavelet : std::false_type - { - }; - template - struct has_is_wavelet> : std::true_type - { - }; - static constexpr bool kIsWavelet = []() { - if constexpr(has_is_wavelet::value) - return GemmPipeline::IsWavelet; - else - return false; - }(); - /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * @@ -1109,38 +1081,8 @@ struct GroupedConvolutionBackwardWeightKernel const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - if constexpr(kIsWavelet) - { - // Wavelet: math waves run the epilogue, load waves run matching barriers - if(GemmPipeline::IsMathWave()) - { - if(kargs.k_batch == 1) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - else - { - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value)) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); - } - } - } - else - { - // Load waves: match epilogue barrier count to avoid deadlock - EpiloguePipeline::RunBarrierStub(); - } - } - else - { - // Standard (non-wavelet) path + // Run the epilogue with split-K dispatch, wrapped for wavelet load/math waves. + RunWaveletAwareEpilogue([&]() { if(kargs.k_batch == 1) { auto c_block_window = MakeCBlockWindow( @@ -1159,7 +1101,7 @@ struct GroupedConvolutionBackwardWeightKernel EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); } } - } + }); } CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 48979b09a2..a838503962 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -572,38 +572,9 @@ struct GroupedConvolutionForwardKernel using GemmDsLayout = remove_cvref_t; static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; - // For wavelet, LaunchBlockSize > BlockSize (extra load-only waves). Use - // LaunchBlockSize for the kernel launch; non-wavelet pipelines fall back to BlockSize. - template - struct has_launch_block_size : std::false_type - { - }; - template - struct has_launch_block_size> : std::true_type - { - }; - static constexpr index_t kBlockSize = []() { - if constexpr(has_launch_block_size::value) - return Pipeline::LaunchBlockSize; - else - return Pipeline::BlockSize; - }(); - - // SFINAE helper: detect Pipeline::IsWavelet (load/math wave specialization). - template - struct has_is_wavelet : std::false_type - { - }; - template - struct has_is_wavelet> : std::true_type - { - }; - static constexpr bool kIsWavelet = []() { - if constexpr(has_is_wavelet::value) - return Pipeline::IsWavelet; - else - return false; - }(); + // Wavelet pipelines launch extra load waves (LaunchBlockSize > BlockSize); others use + // BlockSize. See GroupedConvLaunchBlockSize in grouped_convolution_utils.hpp. + static constexpr index_t kBlockSize = GroupedConvLaunchBlockSize; using InDataType = remove_cvref_t; using WeiDataType = remove_cvref_t; @@ -1375,14 +1346,11 @@ struct GroupedConvolutionForwardKernel const auto& c_block_tile = Pipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline with k_batch dispatching - if constexpr(kIsWavelet) - { - // Wavelet: only math waves hold accumulators and run the epilogue. - // Load waves run a matching barrier sequence to avoid LDS-sync deadlock. - // Forward has no split-K (IsSplitKSupported == false), so only the - // memory_operation_enum::set path is reachable. - if(Pipeline::IsMathWave()) + // Run the epilogue with k_batch dispatch, wrapped for wavelet load/math waves. + // Forward has no split-K (IsSplitKSupported == false), so the atomic_add branch + // compiles out and only the set path is reachable. + RunWaveletAwareEpilogue([&]() { + if(k_batch == 1) { auto c_block_window = MakeCBlockWindow( c_ptr, c_desc, block_idx_m, block_idx_n); @@ -1393,32 +1361,19 @@ struct GroupedConvolutionForwardKernel } else { - EpiloguePipeline::RunBarrierStub(); - } - } - else if(k_batch == 1) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, c_desc, block_idx_m, block_idx_n); + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + is_any_of::value) && + IsSplitKSupported) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, ds_block_window, smem_ptr_0); - } - else - { - if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && - is_any_of::value) && - IsSplitKSupported) - { - auto c_block_window = MakeCBlockWindow( - c_ptr, c_desc, block_idx_m, block_idx_n); - - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } - } + }); } CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const diff --git a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp index 6c840d212b..8feb9e4089 100644 --- a/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp +++ b/include/ck_tile/ops/grouped_convolution/utils/grouped_convolution_utils.hpp @@ -21,6 +21,71 @@ enum class GroupedConvDirection BACKWARD_WEIGHT }; +// Wavelet pipeline support shared by all three grouped-conv directions. The wavelet GEMM +// pipeline launches extra load-only waves (LaunchBlockSize > BlockSize) and splits the +// workgroup into math waves (hold accumulators, run the epilogue) and load waves (run a +// matching barrier sequence). Non-wavelet pipelines expose neither member; these helpers +// detect that via SFINAE so each kernel dispatches without duplicating the machinery. +namespace impl { +template +struct has_launch_block_size : std::false_type +{ +}; +template +struct has_launch_block_size> : std::true_type +{ +}; + +template +struct has_is_wavelet : std::false_type +{ +}; +template +struct has_is_wavelet> : std::true_type +{ +}; +} // namespace impl + +// Block size to launch with: wavelet pipelines need LaunchBlockSize (load + math waves); +// all others fall back to BlockSize. +template +inline constexpr index_t GroupedConvLaunchBlockSize = []() { + if constexpr(impl::has_launch_block_size::value) + return Pipeline::LaunchBlockSize; + else + return Pipeline::BlockSize; +}(); + +// True when the pipeline uses wavelet load/math wave specialization. +template +inline constexpr bool is_wavelet_pipeline = []() { + if constexpr(impl::has_is_wavelet::value) + return Pipeline::IsWavelet; + else + return false; +}(); + +// Run the CShuffle epilogue with wavelet load/math wave dispatch. For wavelet pipelines only +// the math waves run @p epilogue_body (which writes the C tile); load waves run a matching +// barrier sequence (RunBarrierStub) to avoid an LDS-sync deadlock. Non-wavelet pipelines run +// @p epilogue_body directly. The body is direction-specific (split-K dispatch, window +// construction), so it is passed in rather than shared. +template +CK_TILE_DEVICE void RunWaveletAwareEpilogue(EpilogueBody&& epilogue_body) +{ + if constexpr(is_wavelet_pipeline) + { + if(GemmPipeline::IsMathWave()) + epilogue_body(); + else + EpiloguePipeline::RunBarrierStub(); + } + else + { + epilogue_body(); + } +} + /// @brief The Grouped Conv kernel host arguments. /// /// @par Overview