diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp index 4872679754..403959afff 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp @@ -144,39 +144,18 @@ struct DeviceGroupedConvBwdWeight_Explicit end(e_g_k_c_xs_lengths), begin(filter_spatial_lengths_)); - if constexpr(IsTwoStageNeeded) + if(split_k < 0) { - if(split_k < 0) - { - const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = - DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); - const index_t grid_size = gdx * gdy * gdz; - split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); - } - else - { - split_k_ = split_k; - } + const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); + index_t gdx, gdy, gdz; + std::tie(gdx, gdy, gdz) = + DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); + const index_t grid_size = gdx * gdy * gdz; + split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); } else { -#if !DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if(split_k < 0) - { - const auto max_occupancy = DeviceGemmV3Op::GetMaxOccupancy(); - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = - DeviceGemmV3Op::GridwiseGemm::CalculateGridSize(M, N, BatchSize); - const index_t grid_size = gdx * gdy * gdz; - split_k_ = get_best_occupancy_k_batch_value(max_occupancy, grid_size); - } - else -#endif - { - split_k_ = split_k; - } + split_k_ = split_k; } if constexpr(IsTwoStageNeeded) @@ -339,16 +318,6 @@ struct DeviceGroupedConvBwdWeight_Explicit static bool IsSupportedArgument(const Argument& arg) { -#if DISABLE_SPLIT_K_AUTODEDUCE_FOR_ONE_STAGE_KERNELS - if constexpr(!IsTwoStageNeeded) - { - if(arg.split_k_ < 0) - { - return false; - } - } -#endif - if constexpr(NDimSpatial == 2) { if constexpr(!is_NHWGC_GKYXC_NHWGK()) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp index 26b4da349b..676003b96d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp @@ -322,7 +322,8 @@ struct DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB, false, - false>; + false, + true>; static constexpr auto MakeElementwiseInputSequence() { diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp index 049118597a..4e5668e957 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp @@ -374,7 +374,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB, false, - false>; + false, + true>; using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp index aa4ff80719..9f892a2fc4 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp @@ -396,7 +396,8 @@ struct DeviceGroupedConvBwdWeight_Wmma_CShuffleV3 ComputeTypeA, ComputeTypeB, false, - false>; + false, + true>; // Argument using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 465952e285..0d0cd344d8 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -289,7 +289,8 @@ struct ABTransferThreadTiles __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor, BlockDescriptor& block_descriptor, ABElementwiseOperation& ab_element_op, - const index_t block_mn_id) + const index_t block_mn_id, + const index_t k_id) { constexpr index_t NumABTensor = ABsDataType::Size(); const index_t mn_block_data_idx_on_grid = @@ -298,7 +299,7 @@ struct ABTransferThreadTiles if constexpr(NumABTensor > 1) { const auto idx_as_block_begin = generate_tuple( - [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); }, + [&](auto) { return make_multi_index(k_id, mn_block_data_idx_on_grid, 0); }, Number{}); return ThreadGroupTensorSliceTransfer_v7r2< @@ -351,7 +352,7 @@ struct ABTransferThreadTiles ABThreadTransferSrcResetCoordinateAfterRun, true, GlobalBufferNum>(grid_descriptor[I0], - make_multi_index(0, mn_block_data_idx_on_grid, 0), + make_multi_index(k_id, mn_block_data_idx_on_grid, 0), ab_element_op, block_descriptor, make_multi_index(0, 0, 0), diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp index 68476ef3bf..a877a7b389 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp @@ -264,7 +264,8 @@ struct ABTransferWaveTiles __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor, BlockDescriptor& block_descriptor, ABElementwiseOperation& ab_element_op, - const index_t block_mn_id) + const index_t block_mn_id, + const index_t) { // Note: GlobalBufferNum is currently not used but it will be needed // once we add other pipelines. It is currently needed only for diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index f27963d8f5..1d88b8c352 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -176,7 +176,7 @@ template + bool ForceThreadTileTransfer = false> struct GridwiseGemm_wmma_cshuffle_v3 : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, @@ -327,8 +327,6 @@ struct GridwiseGemm_wmma_cshuffle_v3 using typename Base::AsGridPointer; using typename Base::BsGridPointer; using typename Base::DsGridPointer; - using AsDataType_ = AsDataType; - using BsDataType_ = BsDataType; struct Problem { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp index 1b8a8ef09e..9a9f9dc560 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -221,8 +221,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale using typename Base::AsGridPointer; using typename Base::BsGridPointer; using typename Base::DsGridPointer; - using AsDataType_ = AsDataType; - using BsDataType_ = BsDataType; struct Problem { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 57aeaeb4cb..1b409ed315 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -3,11 +3,6 @@ #pragma once -#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC) -#include -#include -#endif - #include "ck/utility/env.hpp" #include "ck/utility/common_header.hpp" #include "ck/tensor_description/multi_index_transform_helper.hpp" @@ -789,27 +784,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base { if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages) { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "Pipeline validation failed: num_k_loop (" << num_k_loop - << ") <= PrefetchStages (" << BlockwiseGemmPipe::PrefetchStages - << ") for pipeline version != v1." << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; - } - return false; - } - } - - if constexpr(is_same, int8_t>::value) - { - if(karg.KBatch > 1) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "int8_t does not support KBatch > 1. KBatch: " << karg.KBatch - << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ - << std::endl; - } return false; } } @@ -873,6 +847,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base } } + // Note: arguments k_batch and k_id should be set if splitk is used + // with implicit gemm (no pointer shift but shift using tensor descriptors) template ( - as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id); + as_grid_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1, a_element_op, block_m_id, k_id); // B matrix blockwise copy auto b_blockwise_copy = @@ -951,7 +927,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base BsDataType, BElementwiseOperation, BlockwiseGemmPipe::GlobalBufferNum>( - bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id); + bs_grid_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1, b_element_op, block_n_id, k_id); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -976,7 +952,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer(); const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane( - ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / KPerBlock); + ATransfer::GetKDimension(as_grid_desc_ak0_m_ak1[I0]) / (KPerBlock * k_batch)); blockwise_gemm_pipeline.template Run( get_first_element_workaround(as_grid_desc_ak0_m_ak1),