diff --git a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index 00518b369f..72c011bfb2 100644 --- a/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -153,7 +153,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle( const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_as_grid + a_batch_offset, p_bs_grid + b_batch_offset, p_ds_grid_grp, @@ -439,7 +439,7 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = ck::conditional_t, ADataType>; using GemmBDataType = ck::conditional_t, BDataType>; -#define GridwiseGemmTemplateParameters \ +#define GridwiseGemmMultiABDTemplateParameters \ GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -454,11 +454,26 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched // Use appropriate gridwise gemm - using GridwiseGemm = - ck::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemm = ck::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers. using APointers = ck::conditional_t&, const void*>; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp index d53fbca4ea..fc1a2b995a 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -80,19 +80,20 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -556,7 +557,6 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp index 25a9d7f96d..0cd1d84a43 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp @@ -88,19 +88,20 @@ __global__ void __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - ck::Tuple<>{}, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ck::Tuple<>{}, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + ck::Tuple<>{}, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ck::Tuple<>{}, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -344,7 +345,6 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp index 630f143260..12085edaae 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -107,19 +107,20 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_k0_m_k1, - b_grid_desc_k0_n_k1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_k0_m_k1, + b_grid_desc_k0_n_k1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -336,7 +337,6 @@ struct DeviceBatchedGemmMultiD_Xdl : public DeviceBatchedGemmMultiD(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -324,7 +325,6 @@ struct DeviceContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp index 3fae3a3765..6c4195e75d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -57,19 +57,20 @@ __global__ void #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - GridwiseGemm::template Run(p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + GridwiseGemm::template Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); #else ignore = p_a_grid; ignore = p_b_grid; @@ -257,7 +258,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD( + GridwiseGemm::template Run( contraction_arg_ptr[group_id].p_a_grid_, contraction_arg_ptr[group_id].p_b_grid_, contraction_arg_ptr[group_id].p_ds_grid_, @@ -368,7 +368,6 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 41f596d160..f18ce40fc5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -15,7 +15,6 @@ #include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" #include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp" @@ -71,7 +70,8 @@ template + bool HasMainKBlockLoop, + InMemoryDataOperationEnum OutElementOp> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -92,12 +92,14 @@ __global__ void e_grid_desc_mblock_mperblock_nblock_nperblock_, const Block2ETileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n) + const ComputePtrOffsetOfN compute_ptr_offset_of_n, + const index_t KBatch) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // offset base pointer for each work-group - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z / KBatch); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z - n_idx * KBatch); const long_index_t a_batch_offset = amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); @@ -123,19 +125,22 @@ __global__ void static_for<0, NumDTensor, 1>{}( [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; }); - GridwiseGemm::template Run(p_a_grid + a_batch_offset + a_n_offset, - p_b_grid + b_batch_offset, - p_ds_grid_grp, - p_e_grid + e_batch_offset + e_n_offset, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock_, - block_2_ctile_map); + GridwiseGemm::template Run( + p_a_grid + a_batch_offset + a_n_offset, + p_b_grid + b_batch_offset, + p_ds_grid_grp, + p_e_grid + e_batch_offset + e_n_offset, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock_, + block_2_ctile_map, + KBatch, + k_idx); #else ignore = p_a_grid; ignore = p_b_grid; @@ -154,151 +159,6 @@ __global__ void #endif } -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t num_k_per_block) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) - // offset base pointer for each work-group - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / karg.KBatch); - const index_t k_idx = - __builtin_amdgcn_readfirstlane((blockIdx.y - n_idx * karg.KBatch) * num_k_per_block); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); -#else - ignore = karg; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = compute_ptr_offset_of_batch; - ignore = compute_ptr_offset_of_n; - ignore = num_k_per_block; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) -} - -template -__global__ void -#if CK_USE_LAUNCH_BOUNDS - __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds( - typename GridwiseGemm::Argument karg, - const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, - const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, - const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock - c_grid_desc_mblock_mperblock_nblock_nperblock, - const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, - const ComputePtrOffsetOfN compute_ptr_offset_of_n, - const index_t num_k_per_block) -{ -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) - const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z); - const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.y / karg.KBatch); - const index_t k_idx = - __builtin_amdgcn_readfirstlane((blockIdx.y - n_idx * karg.KBatch) * num_k_per_block); - - const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); - const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); - const long_index_t e_batch_offset = __builtin_amdgcn_readfirstlane( - static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); - - const long_index_t a_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - const long_index_t e_n_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx)); - - // Pass two lds pointer is the key to tell compiler that ds_read/write - // operate on different lds chunk at same time without order dependecy - __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - GridwiseGemm::template Run_2Lds(karg.p_a_grid + a_batch_offset + a_n_offset, - karg.p_b_grid + b_batch_offset, - karg.p_c_grid + e_batch_offset + e_n_offset, - p_shared_0, - p_shared_1, - karg, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - c_grid_desc_mblock_mperblock_nblock_nperblock, - k_idx); -#else - ignore = karg; - ignore = a_grid_desc_ak0_m_ak1; - ignore = b_grid_desc_bk0_n_bk1; - ignore = c_grid_desc_mblock_mperblock_nblock_nperblock; - ignore = compute_ptr_offset_of_batch; - ignore = compute_ptr_offset_of_n; - ignore = num_k_per_block; -#endif // end of if (defined(__gfx908__) || defined(__gfx90a__)) -} - } // namespace // Conv backward data multiple D: @@ -358,9 +218,7 @@ template + index_t MaxTransposeTransferOutScalarPerVector = 1> struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 : public DeviceGroupedConvBwdDataMultipleD 0; static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; static constexpr bool IsSplitKSupported = (CDEBlockTransferScalarPerVector_NPerBlock % 2 == 0 || sizeof(EDataType) % 4 == 0) && @@ -473,59 +330,25 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // GridwiseGemm #define GridwiseGemmMultiDTemplateParams \ ABDataType, ABDataType, AComputeType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ - AElementwiseOp, BElementwiseOp, CDEElementwiseOp, InMemoryDataOperationEnum::Set, \ - NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ - NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ - ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ - ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ - ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ - BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ - BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ - BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ - BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + AElementwiseOp, BElementwiseOp, CDEElementwiseOp, NumGemmKPrefetchStage, BlockSize, \ + MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ + ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ + ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ + ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ + BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ + BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ + BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ + CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, BComputeType - -#define GridwiseGemmTemplateParams \ - tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor, tensor_layout::gemm::RowMajor, \ - ADataType, BDataType, AccDataType, CShuffleDataType, EDataType, AElementwiseOp, \ - BElementwiseOp, CDEElementwiseOp, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, \ - AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ - ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ - ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ - ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ - ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ - BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ - BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ - BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ - CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ - CDEBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, \ - AComputeType, BComputeType - - using GridwiseGemm = - std::conditional_t, - GridwiseGemm_xdl_cshuffle_v3>; + using GridwiseGemm = GridwiseGemmMultipleD_xdl_cshuffle; template static auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N e_grid_desc_m_n) { - if constexpr(isMultiD) - { - return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n); - } - else - { - const index_t M = e_grid_desc_m_n.GetLength(I0); - const index_t N = e_grid_desc_m_n.GetLength(I1); - return GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n, - GridwiseGemm::CalculateMBlock(M), - GridwiseGemm::CalculateNBlock(N)); - } + return GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); } template @@ -850,46 +673,34 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const auto b_grid_desc_n_k = transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); - if constexpr(isMultiD) - { - a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); - b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); - ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); - e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); - } + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n); + e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n); // desc for blockwise copy a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1); b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1); - if constexpr(isMultiD) + // block-to-e-tile-map + auto block_2_etile_map = + GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + + block_2_etile_map_container_.push_back(block_2_etile_map); + + if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, + b_grid_desc_n_k, + ds_grid_desc_m_n, + e_grid_desc_m_n, + block_2_etile_map, + k_batch_)) { - // block-to-e-tile-map - auto block_2_etile_map = - GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n); + ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - block_2_etile_map_container_.push_back(block_2_etile_map); + GridwiseGemm:: + MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + ds_grid_desc_m_n)); - if(GridwiseGemm::CheckValidity(a_grid_desc_m_k, - b_grid_desc_n_k, - ds_grid_desc_m_n, - e_grid_desc_m_n, - block_2_etile_map)) - { - ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - - GridwiseGemm:: - MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - ds_grid_desc_m_n)); - - e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( - MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( - e_grid_desc_m_n)); - } - } - else - { - // there is no need to check since M, N, K are padded e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n)); @@ -1083,12 +894,13 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { using Argument = DeviceOp::Argument; + template float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; const index_t gdy = arg.num_group_; - const index_t gdz = arg.num_workgroups_per_Conv_N_; + const index_t gdz = arg.num_workgroups_per_Conv_N_ * arg.k_batch_; const ADataType* p_a_grid = arg.p_a_grid_; const BDataType* p_b_grid = arg.p_b_grid_; @@ -1117,7 +929,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.b_grid_desc_n_k_container_[i], arg.ds_grid_desc_m_n_container_[i], arg.e_grid_desc_m_n_container_[i], - arg.block_2_etile_map_container_[i])) + arg.block_2_etile_map_container_[i], + arg.k_batch_)) { throw std::runtime_error("wrong! device_op has invalid setting"); } @@ -1145,7 +958,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 Block2ETileMap, ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch, - has_main_loop>; + has_main_loop, + ElementOp>; return launch_and_time_kernel( stream_config, @@ -1166,10 +980,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], arg.block_2_etile_map_container_[i], arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_); + arg.compute_ptr_offset_of_n_, + arg.k_batch_); }; - if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK, arg.k_batch_)) { ave_time += launch_kernel(integral_constant{}); } @@ -1182,678 +997,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 return ave_time; } - float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - float ave_time = 0; - - const ADataType* p_a_grid = arg.p_a_grid_; - const BDataType* p_b_grid = arg.p_b_grid_; - EDataType* p_e_grid = arg.p_e_grid_; - - if constexpr(is_NGCHW_NGKHW() || - is_NGCDHW_NGKDHW()) - { - p_a_grid = type_convert(arg.p_workspace_); - p_e_grid = - type_convert(arg.p_workspace_) + - (arg.GetWorkspaceATensorSizeBytes() + arg.GetWorkspaceBTensorSizeBytes()) / - sizeof(EDataType); - } - - if constexpr(is_NGCHW_GKCYX_NGKHW() || - is_NGCDHW_GKCZYX_NGKDHW()) - { - p_b_grid = type_convert(arg.p_workspace_) + - arg.GetWorkspaceATensorSizeBytes() / sizeof(BDataType); - } - - constexpr index_t minimum_occupancy = - BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2; - - for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) - { - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); - const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * - arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); - - const auto num_k_per_block = - arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(Number<0>{}) / arg.k_batch_; - - // gdy is for the kbatch and num_workgrups_per_Conv_N - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( - GemmM, GemmN, arg.k_batch_ * arg.num_workgroups_per_Conv_N_, arg.num_group_); - - index_t k_grain = arg.k_batch_ * KPerBlock; - index_t K_split = (GemmK + k_grain - 1) / k_grain * KPerBlock; - const bool has_main_k_block_loop = - GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - typename GridwiseGemm::Argument gemm_arg{ - p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - typename GridwiseGemm::Argument gemm_arg_ = gemm_arg; - ck::utility::RotatingMemWrapper - rotating_mem(gemm_arg_, - stream_config.rotating_count, - gemm_arg_.M * gemm_arg_.K * sizeof(ADataType), - gemm_arg_.K * gemm_arg_.N * sizeof(BDataType)); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - }; - - ave_time += ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - gemm_arg_, - arg.a_grid_desc_ak0_m_ak1_container_[i], - arg.b_grid_desc_bk0_n_bk1_container_[i], - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], - arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_, - num_k_per_block); - } - else - { - ave_time += launch_and_time_kernel( - stream_config, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - gemm_arg, - arg.a_grid_desc_ak0_m_ak1_container_[i], - arg.b_grid_desc_bk0_n_bk1_container_[i], - arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i], - arg.compute_ptr_offset_of_batch_, - arg.compute_ptr_offset_of_n_, - num_k_per_block); - } - }; - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - } - else - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - // Tail number could be One to Seven - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2) - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::One) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::One>; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Full>; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Two) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Two>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Three>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Four>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Five>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Six) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Six>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp:: - EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Seven>; - Run(kernel); - } - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::One>; - Run(kernel); - } - else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Full) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Full>; - Run(kernel); - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Two) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Two>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Three) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Three>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Four) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Four>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Five) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Five>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Six) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Six>; - Run(kernel); - } - } - - if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Seven) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Seven>; - Run(kernel); - } - } - } - } - // Tail number could be Odd or Even - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Odd) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3_2lds< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == - TailNumber::Odd) - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = - kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(gemm_arg.KBatch > 1) - { - if constexpr(IsSplitKSupported) - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - } - else - { - const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< - GridwiseGemm, - DeviceOp::AGridDesc_AK0_M_AK1, - DeviceOp::BGridDesc_BK0_N_BK1, - DeviceOp::EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock, - ComputePtrOffsetOfStridedBatch, - ComputePtrOffsetOfStridedBatch, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy>; - Run(kernel); - } - } - } - } - return ave_time; - } - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { float ave_time = 0; @@ -1940,14 +1083,17 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 static_cast(arg.compute_ptr_offset_of_n_.BatchStrideA_)}, std::array{0}); } - - if constexpr(isMultiD) + if(arg.k_batch_ > 1) { - ave_time += RunMultiDGemm(arg, stream_config); + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } } else { - ave_time += RunGemmV3(arg, stream_config); + ave_time += RunMultiDGemm(arg, stream_config); } // Transpose from NHWGC to NGCHW @@ -2031,29 +1177,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; - if constexpr(!isMultiD) - { - for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) - { - const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); - const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); - const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * - arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); - - typename GridwiseGemm::Argument gemm_arg{ - nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; - - const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / AK1); - if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) - { - if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) - { - return false; - } - } - } - } - // Specifialization if constexpr(ConvBackwardDataSpecialization == ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) @@ -2156,16 +1279,14 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 // Gridwise GEMM size for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) { - if constexpr(isMultiD) + if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], + arg.b_grid_desc_n_k_container_[i], + arg.ds_grid_desc_m_n_container_[i], + arg.e_grid_desc_m_n_container_[i], + arg.block_2_etile_map_container_[i], + arg.k_batch_)) { - if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_container_[i], - arg.b_grid_desc_n_k_container_[i], - arg.ds_grid_desc_m_n_container_[i], - arg.e_grid_desc_m_n_container_[i], - arg.block_2_etile_map_container_[i])) - { - return false; - } + return false; } } @@ -2322,17 +1443,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 { auto str = std::stringstream(); - std::map BlkGemmPipelineSchedulerToString{ - {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, - {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; - - std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, - {BlockGemmPipelineVersion::v2, "v2"}, - {BlockGemmPipelineVersion::v3, "v3"}, - {BlockGemmPipelineVersion::v4, "v4"}, - {BlockGemmPipelineVersion::v5, "v5"}}; - // clang-format off str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1" << "<" @@ -2350,11 +1460,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 << ABlockTransferSrcScalarPerVector << ", " << BBlockTransferSrcScalarPerVector << ", " << CShuffleMXdlPerWavePerShuffle << ", " - << CShuffleNXdlPerWavePerShuffle << ", " - << "BlkGemmPipelineScheduler: " - << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " - << "BlkGemmPipelineVersion: " - << BlkGemmPipelineVersionToString[BlkGemmPipelineVer]; + << CShuffleNXdlPerWavePerShuffle; if constexpr(is_NGCHW_NGKHW() || is_NGCDHW_NGKDHW()) { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index c0148c3b9c..27da1d91a3 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -179,7 +179,7 @@ __global__ void const long_index_t a_n_offset = amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx)); - GridwiseGemm::template Run( + GridwiseGemm::template Run( p_as_grid + a_group_offset + a_n_offset, p_bs_grid + b_group_offset, p_ds_grid_grp, @@ -434,7 +434,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle using GemmADataType = std::conditional_t, ADataType>; using GemmBDataType = std::conditional_t, BDataType>; -#define GridwiseGemmTemplateParameters \ +#define GridwiseGemmMultiABDTemplateParameters \ GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ @@ -450,11 +450,27 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ BComputeDataType + +#define GridwiseGemmTemplateParameters \ + GemmADataType, GemmBDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, \ + EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ + CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ + BComputeDataType // Use appropriate gridwise gemm - using GridwiseGemm = - std::conditional_t, - GridwiseGemmMultipleD_xdl_cshuffle>; + using GridwiseGemm = std::conditional_t< + isMultiA || isMultiB, + GridwiseGemmMultipleABD_xdl_cshuffle, + GridwiseGemmMultipleD_xdl_cshuffle>; // If ADataTypes or BDataTypes is tuple, user has to pass std::array with pointers. using APointers = diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index 3c34d77cc9..94a4e0da4c 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -89,7 +89,7 @@ __global__ void group_id = index_t((left + right) / 2); } - GridwiseGemm::template Run( + GridwiseGemm::template Run( gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset, gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset, Tuple<>{}, @@ -350,16 +350,15 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor #define GridwiseGemmTemplateParameters \ ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \ AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \ - InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \ - KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \ - ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \ - ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \ - ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \ - ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \ - BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \ - BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \ - BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \ - CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ + NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, \ + NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, \ + ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, \ + ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, \ + ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, \ + BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, \ + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, \ + BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, \ + BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \ CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \ CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1, \ AComputeDataType diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp index aa70a24fc1..cbee4e09f4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp @@ -65,7 +65,7 @@ __global__ void group_id = index_t((left + right) / 2); } - GridwiseGemm::template Run( + GridwiseGemm::template Run( gemm_desc_ptr[group_id].a_ptr_, gemm_desc_ptr[group_id].b_ptr_, gemm_desc_ptr[group_id].ds_ptr_, @@ -242,7 +242,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); @@ -550,6 +554,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle return; } + const index_t num_k_per_block = + __builtin_amdgcn_readfirstlane(a_grid_desc_ak0_m_ak1.GetLength(I0) / k_batch); + // HACK: this force m/n_block_data_idx_on_grid into SGPR const index_t m_block_data_idx_on_grid = __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock); @@ -591,7 +598,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( a_grid_desc_ak0_m_ak1, - make_multi_index(0, m_block_data_idx_on_grid, 0), + make_multi_index(num_k_per_block * k_idx, m_block_data_idx_on_grid, 0), a_element_op, a_block_desc_ak0_m_ak1, make_multi_index(0, 0, 0), @@ -622,7 +629,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle true, NumGemmKPrefetchStage>( b_grid_desc_bk0_n_bk1, - make_multi_index(0, n_block_data_idx_on_grid, 0), + make_multi_index(num_k_per_block * k_idx, n_block_data_idx_on_grid, 0), b_element_op, b_block_desc_bk0_n_bk1, make_multi_index(0, 0, 0), @@ -688,7 +695,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / - KPerBlock); + (KPerBlock * k_batch)); gridwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, @@ -943,6 +950,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle } template (p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); } template (p_a_grid, - p_b_grid, - p_ds_grid, - p_e_grid, - p_shared, - a_element_op, - b_element_op, - cde_element_op, - a_grid_desc_ak0_m_ak1, - b_grid_desc_bk0_n_bk1, - ds_grid_desc_mblock_mperblock_nblock_nperblock, - e_grid_desc_mblock_mperblock_nblock_nperblock, - block_2_etile_map); + Run( + p_a_grid, + p_b_grid, + p_ds_grid, + p_e_grid, + p_shared, + a_element_op, + b_element_op, + cde_element_op, + a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + ds_grid_desc_mblock_mperblock_nblock_nperblock, + e_grid_desc_mblock_mperblock_nblock_nperblock, + block_2_etile_map); } };