diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp index 9a208b99c1..ed998a8ecb 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_xdl_producer_consumer_cshuffle.hpp @@ -1,14 +1,15 @@ #pragma once + #include #include -#include "device.hpp" -#include "device_gemm.hpp" -#include "common_header.hpp" -#include "tensor_layout.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "gridwise_gemm_xdl_producer_consumer_cshuffle.hpp" -#include "gemm_specialization.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp" namespace ck { namespace tensor_operation { @@ -437,7 +438,7 @@ struct DeviceGemm_Xdl_ProducerConsumer_CShuffle { using Argument = DeviceOp::Argument; - float Run(const Argument& arg, int nrepeat = 1) + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { #if 0 { @@ -484,42 +485,22 @@ struct DeviceGemm_Xdl_ProducerConsumer_CShuffle typename GridwiseGemm::DefaultBlock2CTileMap, true>; - if(nrepeat == 0) - { - launch_kernel(kernel, - dim3(grid_size), - dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - } - else - { - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - } + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); } else { @@ -536,51 +517,32 @@ struct DeviceGemm_Xdl_ProducerConsumer_CShuffle typename GridwiseGemm::DefaultBlock2CTileMap, false>; - if(nrepeat == 0) - { - launch_kernel(kernel, - dim3(grid_size), - dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - } - else - { - ave_time = launch_and_time_kernel( - kernel, - nrepeat, - dim3(grid_size), - dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize), - 0, - arg.p_a_grid_, - arg.p_b_grid_, - arg.p_c_grid_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_, - arg.a_grid_desc_ak0_m_ak1_, - arg.b_grid_desc_bk0_n_bk1_, - arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, - arg.block_2_ctile_map_); - } + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(grid_size), + dim3(ABBlockTransferThreadGroupSize + BlockGemmThreadGroupSize), + 0, + arg.p_a_grid_, + arg.p_b_grid_, + arg.p_c_grid_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_, + arg.a_grid_desc_ak0_m_ak1_, + arg.b_grid_desc_bk0_n_bk1_, + arg.c_grid_desc_mblock_mperblock_nblock_nperblock_, + arg.block_2_ctile_map_); } return ave_time; } // polymorphic - float Run(const BaseArgument* p_arg, int nrepeat = 1) override + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override { - return Run(*dynamic_cast(p_arg), nrepeat); + return Run(*dynamic_cast(p_arg), stream_config); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp index 6b04881715..2374f6e716 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp @@ -1,6 +1,6 @@ #pragma once -#include "common_header.hpp" +#include "ck/utility/common_header.hpp" namespace ck { diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp index f5cf8de41a..915207844b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_producer_consumer_cshuffle.hpp @@ -1,13 +1,14 @@ #pragma once -#include "common_header.hpp" -#include "multi_index_transform_helper.hpp" -#include "tensor_descriptor.hpp" -#include "tensor_descriptor_helper.hpp" -#include "blockwise_gemm_xdlops.hpp" -#include "thread_group_tensor_slice_transfer_v4r1.hpp" -#include "thread_group_tensor_slice_transfer_v6r1.hpp" -#include "threadwise_tensor_slice_transfer.hpp" -#include "gridwise_gemm_pipeline_producer_consumer.hpp" + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/multi_index_transform_helper.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_producer_consumer.hpp" +#include "ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp" +#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" namespace ck { @@ -455,7 +456,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_producer_consumer_cshuffle math::lcm(AK1, BK1), MfmaSelector::selected_mfma.k_per_blk); auto blockwise_gemm = - BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1