diff --git a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp index 77a9fe4271..df8351602b 100644 --- a/example/ck_tile/03_gemm/gemm_basic_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_basic_invoker.hpp @@ -69,107 +69,88 @@ struct BasicInvoker using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC>>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - CodegenPipelineProblem::TransposeC, - memory_operation>>; + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - // ToDo: Will add the codegen part to test different pipeline policies in GEMM. - // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." 
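// [Editor's note] The `Run` lambda removed above implemented tag dispatch:
// a runtime k_batch check picked a std::integral_constant tag so that
// `memory_operation` became a compile-time constant (and a distinct kernel
// instantiation) inside the lambda. A minimal stand-alone sketch of that
// idiom, with hypothetical names:
#include <type_traits>

enum class mem_op
{
    set,       // k_batch == 1: each output tile is written exactly once
    atomic_add // k_batch > 1 : several K-splits accumulate into one tile
};

template <typename Runner>
float dispatch_mem_op(int k_batch, Runner run)
{
    if(k_batch == 1)
        return run(std::integral_constant<mem_op, mem_op::set>{});
    return run(std::integral_constant<mem_op, mem_op::atomic_add>{});
}
// After this patch the epilogue no longer takes the operation as a template
// parameter, so both branches collapse into a single instantiation.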
<< std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp index c312a53c2a..d2460193d8 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_invoker.hpp @@ -72,160 +72,144 @@ struct SplitKTwoStageInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmKernel = ck_tile::GemmKernel; - using GemmKernel = ck_tile::GemmKernel; + ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); + ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); + auto c_ptr = ws_args.c_ptr; + ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); + auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); - ck_tile::DeviceMem ws_m_n_dev_buf(args.M * args.N * sizeof(WorkspaceType)); - ck_tile::GemmHostArgs ws_args = ck_tile::GemmHostArgs(args); - auto c_ptr = ws_args.c_ptr; - ws_args.c_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - auto gemm_kargs = GemmKernel::MakeKernelArgs(ws_args); + const dim3 grids = Persistent ? GemmKernel::MaxOccupancyGridSize(s) + : GemmKernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = GemmKernel::BlockSize(); - const dim3 grids = Persistent ? 
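// [Editor's note] The two-stage split-K flow above hinges on one pointer
// swap: stage 1 (GEMM) accumulates K-splits into an M*N workspace of
// WorkspaceType, and stage 2 (elementwise) converts that into the user's C
// buffer. A minimal host-side sketch, with plain types standing in for the
// ck_tile ones:
#include <vector>

struct HostArgsSketch
{
    void* c_ptr;
    int   M, N;
};

// Returns the user's real output pointer; the GEMM now targets the workspace.
inline void* redirect_to_workspace(HostArgsSketch& ws_args, std::vector<float>& workspace)
{
    void* real_c  = ws_args.c_ptr;    // remember the final C destination
    ws_args.c_ptr = workspace.data(); // stage-1 partials land here instead
    return real_c;                    // stage 2 writes the converted result here
}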
GemmKernel::MaxOccupancyGridSize(s) - : GemmKernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = GemmKernel::BlockSize(); + if(!GemmKernel::IsSupportedArgument(gemm_kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!GemmKernel::IsSupportedArgument(gemm_kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; + ck_tile::index_t total_elements = 1; + std::vector shape = {args.M, args.N}; - ck_tile::index_t total_elements = 1; - std::vector shape = {args.M, args.N}; + for(auto d : shape) + total_elements *= d; - for(auto d : shape) - total_elements *= d; + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; - const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); - constexpr ck_tile::index_t kBlockPerCu = 1; + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); + auto input_size = ck_tile::make_tuple(args.M, args.N); - auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.c_ptr)); - auto input_size = ck_tile::make_tuple(args.M, args.N); + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! 
Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << GemmKernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; - - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - gemm_kargs.as_ptr[0], - gemm_kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel( - GemmKernel{}, grids, blocks, 0, gemm_kargs), - ck_tile::make_kernel(ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(args.N, 1), // Input Stride - ck_tile::make_tuple(args.N, 1), // Output Stride - input_tensors, - static_cast(c_ptr))); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + ws_args.c_ptr, 0, args.M * args.N * sizeof(WorkspaceType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." 
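// [Editor's note] The elementwise stage above sizes its grid by ceiling
// division of the element count by the block tile (2048 above). A tiny
// self-contained check of that arithmetic:
#include <cstdint>

constexpr std::int64_t ceil_div(std::int64_t n, std::int64_t d) { return (n + d - 1) / d; }

static_assert(ceil_div(1000 * 1000, 2048) == 489, "489 blocks cover 1M elements");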
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + gemm_kargs.as_ptr[0], + gemm_kargs.bs_ptr[0], + s.rotating_count_, + size_a_buffer, + size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel( + GemmKernel{}, grids, blocks, 0, gemm_kargs), + ck_tile::make_kernel(ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(args.N, 1), // Input Stride + ck_tile::make_tuple(args.N, 1), // Output Stride + input_tensors, + static_cast(c_ptr))); } }; diff --git a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp index c06dc457c9..64305b85cf 100644 --- a/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp +++ b/example/ck_tile/03_gemm/gemm_splitk_two_stage_reduce.cpp @@ -160,110 +160,101 @@ float gemm_stage1(const GemmSplitKHostArgs& args, const ck_tile::stream_config& args.stride_E); constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&]() { - // use SET operation since each K-split writes to separate memory - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = + ck_tile::CShuffleEpilogue>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(base_args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(base_args); + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + const dim3 blocks = Kernel::BlockSize(); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Stage 1 - Launching GEMM kernel: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); - ck_tile::RotatingMemWrapper rotating_mem( - kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - return ck_tile::launch_kernel_time_mask( - s, - run_flush_cache, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - else - { - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - } - }; - - return Run(); + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + return ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } } /** diff --git a/example/ck_tile/03_gemm/gemm_utils.hpp b/example/ck_tile/03_gemm/gemm_utils.hpp index f79494a478..8eff0e7469 100644 --- a/example/ck_tile/03_gemm/gemm_utils.hpp +++ b/example/ck_tile/03_gemm/gemm_utils.hpp @@ -460,12 +460,6 @@ inline auto create_args() return arg_parser; } -// Type aliases for memory operation integral constants -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = 
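// [Editor's note] The flush-cache benchmarking pattern shared by the invokers
// above: before every timed launch, flush the instruction cache, rotate to a
// fresh copy of the A/B buffers so data caches are cold, and zero the output
// when split-K accumulates into it. A generic sketch with a stand-in rotating
// buffer (the real code uses ck_tile::RotatingMemWrapper and flush_icache):
#include <cstddef>
#include <functional>
#include <vector>

struct RotatingBuffersSketch
{
    std::vector<const void*> copies; // pre-allocated copies of the input buffers
    std::size_t              i = 0;
    const void* next() { return copies[i++ % copies.size()]; } // cycle through copies
};

inline std::function<void()> make_preprocess(RotatingBuffersSketch& rot,
                                             std::function<void()> clear_output)
{
    return [&rot, clear_output] {
        rot.next();     // stand-in for flush_icache() + RotatingMemWrapper::Next()
        clear_output(); // only needed when k_batch > 1
    };
}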
std::integral_constant; - // host API template ::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" - << std::endl; - } - float ave_time = 0.f; - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - ck_tile::RotatingMemWrapper rotating_mem(kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck_tile::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - ave_time = - ck_tile::launch_kernel_time_mask(s, - run_flush_cache, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - else - { - ave_time = ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, grids, blocks, 0, kargs)); - } - return ave_time; - }; - - if(args.k_batch == 1) + dim3 grids; + if constexpr(Persistent) { - return Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - throw std::runtime_error("split-k is not supported yet!"); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << ", kBlockPerCu: {" << GemmConfig::kBlockPerCu << "}" << std::endl; + } + float ave_time = 0.f; + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_time_mask( + s, + run_flush_cache, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = ck_tile::launch_kernel( + s, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; } }; diff --git a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp index 4a83a2c4ab..fb89e6b4cc 100644 --- a/example/ck_tile/03_gemm/universal_gemm_invoker.hpp +++ b/example/ck_tile/03_gemm/universal_gemm_invoker.hpp @@ -60,112 +60,94 @@ struct UniversalInvoker using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GemmKernel; - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) - : Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Persistent ? Kernel::MaxOccupancyGridSize(s) + : Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - // Declare rotating_mem_ptr here so it stays in scope until it is needed - std::unique_ptr> rotating_mem_ptr; - std::function preprocess; + // Declare rotating_mem_ptr here so it stays in scope until it is needed + std::unique_ptr> rotating_mem_ptr; + std::function preprocess; - auto clear_gemm_output = [&]() { - if(args.k_batch > 1) - hipGetErrorString(hipMemsetAsync( - args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); - }; - - if(s.flush_cache_) - { - std::cout << "Flushing cache..." << std::endl; - - ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( - args.M, args.K, args.stride_A, is_row_major(ALayout{}))); - ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - args.K, args.N, args.stride_B, is_row_major(BLayout{}))); - - auto size_a_buffer = a_m.get_element_space_size_in_bytes(); - auto size_b_buffer = b_n.get_element_space_size_in_bytes(); - - rotating_mem_ptr = - std::make_unique>( - kargs.as_ptr[0], - kargs.bs_ptr[0], - s.rotating_count_, - size_a_buffer, - size_b_buffer); - rotating_mem_ptr->Print(); - - preprocess = [&]() { - ck_tile::flush_icache(); - rotating_mem_ptr->Next(); - clear_gemm_output(); - }; - } - else - { - preprocess = clear_gemm_output; - } - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto clear_gemm_output = [&]() { + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); }; - if(args.k_batch == 1) + if(s.flush_cache_) { - return Run(MemoryOpSet{}); + std::cout << "Flushing cache..." 
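// [Editor's note] Why every invoker zeroes the output before a timed run:
// with k_batch > 1 the epilogue accumulates into C via atomic adds, so stale
// values from a previous iteration would corrupt both results and timings.
// A hedged helper mirroring the clear_gemm_output lambda above (names
// hypothetical, HIP runtime assumed):
#include <cstddef>
#include <hip/hip_runtime.h>

inline void clear_output_if_splitk(void* e_ptr, std::size_t bytes, int k_batch, hipStream_t stream)
{
    if(k_batch > 1)
        (void)hipMemsetAsync(e_ptr, 0, bytes, stream); // async zero on the compute stream
}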
<< std::endl; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes(); + auto size_b_buffer = b_n.get_element_space_size_in_bytes(); + + rotating_mem_ptr = std::make_unique>( + kargs.as_ptr[0], kargs.bs_ptr[0], s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem_ptr->Print(); + + preprocess = [&]() { + ck_tile::flush_icache(); + rotating_mem_ptr->Next(); + clear_gemm_output(); + }; } else { - return Run(MemoryOpAtomicAdd{}); + preprocess = clear_gemm_output; } + + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/16_batched_gemm/batched_gemm.cpp b/example/ck_tile/16_batched_gemm/batched_gemm.cpp index c7e37bc8a7..b68c30351d 100644 --- a/example/ck_tile/16_batched_gemm/batched_gemm.cpp +++ b/example/ck_tile/16_batched_gemm/batched_gemm.cpp @@ -78,63 +78,48 @@ float batched_gemm(const ck_tile::BatchedGemmHostArgs& args, const ck_tile::stre using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_batched_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp index 3ff3f2f10e..a24e4bc8ab 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.cpp @@ -62,71 +62,55 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -161,74 +144,55 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, BLayout, CLayout>; - float 
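// [Editor's note] Grouped-GEMM argument handling in the code above: per-group
// kernel arguments are built on the host, uploaded once into a device
// workspace, and the kernel receives that pointer cast to the constant
// address space plus the group count. The upload step as spelled above:
//
//     HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr,
//                                         kargs.data(),
//                                         get_workspace_size(gemm_descs),
//                                         hipMemcpyHostToDevice,
//                                         s.stream_id_));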
ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - if(!splitk) + if(s.log_level_ > 0) { - return ave_time = Run(ck_tile::integral_constant{}); - } - else - { - return ave_time = - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp index 67b411c1f0..462f11e405 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm.hpp @@ -328,5 +328,4 @@ template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk = false); + void* kargs_ptr); diff --git 
a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp index 060dd311b5..e5aefad8d1 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_multi_d.cpp @@ -61,72 +61,56 @@ float grouped_gemm_multi_d(const std::vector& gemm_d using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: { " + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -163,76 +146,55 @@ float grouped_gemm_multi_d_tileloop(const ck_tile::stream_config& s, BLayout, ELayout>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // We create the GEMM pipeline without specifying hotloop or tailnumber. 
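// [Editor's note] The multi-D variants above thread extra "D" tensors
// (the DsDataType / DsLayout tuples) through the epilogue so they can be
// combined elementwise with the GEMM result. A scalar-level sketch of such a
// fusion, assuming a simple add-all combine (the real elementwise op is a
// template parameter of the epilogue problem):
#include <tuple>

template <typename Acc, typename... Ds>
float fuse_epilogue(Acc acc, const std::tuple<Ds...>& ds)
{
    float out = static_cast<float>(acc);
    // fold the D operands into the accumulator one by one
    std::apply([&](const auto&... d) { ((out += static_cast<float>(d)), ...); }, ds);
    return out;
}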
- // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - if(!splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - return ave_time; + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_multi_d_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp index 4a5be996c0..b4c10900d6 100644 --- a/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp +++ b/example/ck_tile/17_grouped_gemm/grouped_gemm_preshuffle.cpp @@ -65,70 +65,54 @@ float grouped_gemm(const std::vector& gemm_descs, using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Kernel arguments not supported!"); - } - - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(gemm_descs); - - HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + using GemmEpilogue = 
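// [Editor's note] Grid choice in the tile-loop kernels above: instead of one
// workgroup per output tile (Kernel::GridSize(M, N, k_batch)), a persistent
// launch uses Kernel::MaxOccupancyGridSize(s) and lets each workgroup loop
// over tiles. A sketch of occupancy-style sizing; the min() clamp is an
// assumption of this note, not taken from the library:
#include <algorithm>

inline int persistent_grid_size(int num_cus, int blocks_per_cu, int num_tiles)
{
    // fill the device once, but never launch more workgroups than tiles
    return std::min(num_cus * blocks_per_cu, num_tiles);
}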
ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Kernel arguments not supported!"); } - else + + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + + HIP_CHECK_ERROR(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template float grouped_gemm_tileloop(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -167,75 +150,53 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, GemmConfig::NumWaveGroups, GemmConfig::Preshuffle>; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, // DsDataType (empty for no D tensors) + AccDataType, + CDataType, + ck_tile::tuple<>, // DsLayout (empty for no D tensors) + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmPipeline = typename PipelineTypeTraits< - GemmConfig::Pipeline>::template GemmPipeline; - using GemmEpilogue = ck_tile::CShuffleEpilogue, // DsDataType (empty for no D tensors) - AccDataType, - CDataType, - ck_tile::tuple<>, // DsLayout (empty for no D tensors) - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ave_time = - ck_tile::launch_kernel(s, - 
ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - - if(splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } - return ave_time; + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } #include "run_grouped_gemm_example.inc" diff --git a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp index 16352722e1..ea71abb213 100644 --- a/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp +++ b/example/ck_tile/17_grouped_gemm/quant_invoke_grouped_gemm_kernel.hpp @@ -72,10 +72,9 @@ float grouped_gemm(const std::vector& gemm_descs, float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = GemmConfig::Scheduler; constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || QuantMode == ck_tile::QuantType::BQuantGrouped; @@ -137,8 +136,7 @@ float grouped_gemm(const std::vector& gemm_descs, GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; + QuantGemmProblem::TransposeC>>; using Kernel = ck_tile::QuantGroupedGemmKernel; - float ave_time{0}; + constexpr auto scheduler = GemmConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = GemmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::BQuantGrouped; - constexpr bool UseGroupedQuant = QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; + using GemmPipeline = GemmQuantConfig::template GemmPipeline; - using GemmPipeline = - GemmQuantConfig::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + QuantGemmProblem::TransposeC>>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = 
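// [Editor's note] The quant problem selection above nests std::conditional_t
// to pick one of three pipeline problems from the QuantType. A stand-alone
// analog of the same pattern (types hypothetical):
#include <type_traits>

enum class QuantKind { a_grouped, b_grouped, row_col };

struct AQuantProblem {};
struct BQuantProblem {};
struct RowColProblem {};

template <QuantKind K>
using QuantProblemFor =
    std::conditional_t<K == QuantKind::a_grouped || K == QuantKind::b_grouped,
                       std::conditional_t<K == QuantKind::a_grouped, AQuantProblem, BQuantProblem>,
                       RowColProblem>;

static_assert(std::is_same_v<QuantProblemFor<QuantKind::b_grouped>, BQuantProblem>, "");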
Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; - using Kernel = ck_tile::QuantGroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" - << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" - << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - }; - - return ave_time = Run(ck_tile::integral_constant{}); + return ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index 390a54644b..7a01b1dcea 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -79,8 +79,7 @@ float invoke_gemm(int n_warmup, // earlier stage. 
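// [Editor's note] Net caller-side effect, visible in the .inc hunks here:
// the tile-loop entry points no longer take a splitk flag, so call sites drop
// the trailing `args[0].k_batch > 1` argument:
//
//     before: grouped_gemm_tileloop<...>(stream, group_count, kargs_ptr, splitk);
//     after:  grouped_gemm_tileloop<...>(stream, group_count, kargs_ptr);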
std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = args[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : args) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, @@ -109,7 +108,7 @@ float invoke_gemm(int n_warmup, ADataType, BDataType, AccDataType, - CDataType>(stream, group_count, kargs_ptr, splitk); + CDataType>(stream, group_count, kargs_ptr); } return ave_time; diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc index ac6ea99db3..4f2bebdf17 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_multi_d_example.inc @@ -95,8 +95,7 @@ float invoke_gemm(int n_warmup, else { std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = args[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : args) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<1, 1, NumDTensor>{{arg.a_ptr}, @@ -119,18 +118,17 @@ float invoke_gemm(int n_warmup, kargs.size() * sizeof(ck_tile::GemmTransKernelArg), hipMemcpyHostToDevice, stream.stream_id_)); - ave_time = - grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr, splitk); + ave_time = grouped_gemm_multi_d_tileloop(stream, group_count, kargs_ptr); } return ave_time; } diff --git a/example/ck_tile/18_flatmm/flatmm_basic.cpp b/example/ck_tile/18_flatmm/flatmm_basic.cpp index cd241a2be0..af46884a90 100644 --- a/example/ck_tile/18_flatmm/flatmm_basic.cpp +++ b/example/ck_tile/18_flatmm/flatmm_basic.cpp @@ -170,13 +170,10 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, 1, @@ -282,23 +278,7 @@ float flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/grouped_flatmm.cpp b/example/ck_tile/18_flatmm/grouped_flatmm.cpp index da85c95dae..780a21ba14 100644 --- a/example/ck_tile/18_flatmm/grouped_flatmm.cpp +++ b/example/ck_tile/18_flatmm/grouped_flatmm.cpp @@ -113,13 +113,10 @@ float grouped_flatmm(const KernelArguments& args, const 
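// [Editor's note] In the flatmm examples above, the RunSplitk wrapper is gone
// and TailHandler drives Run directly; the only remaining dispatch is over
// loop shape (both lines below appear in the patched code):
//
//     const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { /* ... */ };
//     BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);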
ck_tile::stream_config& const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem>; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. @@ -216,23 +212,7 @@ float grouped_flatmm(const KernelArguments& args, const ck_tile::stream_config& return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp index fe7fe4c5d1..708e8a683e 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/a16w4_moe_flatmm.cpp @@ -113,13 +113,10 @@ float a16w4_moe_gemm(const MoeFlatmmHostArgs& args, const ck_tile::stream_config const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = std::conditional_t{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp index 2b6dbace36..f9f8c0cec7 100644 --- a/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp +++ b/example/ck_tile/18_flatmm/mixed_prec/mixed_prec_flatmm.cpp @@ -89,13 +89,10 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation 
= memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern @@ -128,7 +125,6 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC @@ -201,23 +197,7 @@ float mixed_prec_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/moe_flatmm.cpp b/example/ck_tile/18_flatmm/moe_flatmm.cpp index 96b9ae29a4..4cca953066 100644 --- a/example/ck_tile/18_flatmm/moe_flatmm.cpp +++ b/example/ck_tile/18_flatmm/moe_flatmm.cpp @@ -144,15 +144,11 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - float ave_time{0}; - const auto Run = [&](const auto has_hot_loop_, - const auto tail_number_, - const auto memory_operation_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = FlatmmConfig::Scheduler; using CodegenPipelineProblem = ck_tile::FlatmmPipelineProblem& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, CodegenPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, 1, @@ -261,37 +256,20 @@ float moe_gemm(const ck_tile::MoeFlatmmHostArgs& args, args.NumTokens * args.TopK * outputN * sizeof(CDataType), s.stream_id_)); }; - ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( s, run_flush_cache, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } else { - ave_time = ck_tile::launch_kernel( + return ck_tile::launch_kernel( s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } - return ave_time; }; - const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { - if(args.k_batch == 1) - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - else - { - Run(has_hot_loop_, - tail_number_, - ck_tile::integral_constant{}); - } - }; - BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + float ave_time = BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); return ave_time; } diff --git a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp index 
f177ef04ca..01128f8fe8 100644 --- a/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp +++ b/example/ck_tile/18_flatmm/mxgemm/mx_flatmm_instance.hpp @@ -61,8 +61,7 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, "mixed_prec_flatmm requires ADataType is a wider type than BDataType"); constexpr auto scheduler = FlatmmConfig::Scheduler; - constexpr auto memory_operation = - Splitk ? ck_tile::memory_operation_enum::atomic_add : ck_tile::memory_operation_enum::set; + ck_tile::ignore = Splitk; constexpr int BlockedXDLN_PerWarp = 2; // determined by scale shuffle pattern @@ -98,7 +97,6 @@ float mx_flatmm_calc(const ck_tile::ScaleFlatmmHostArgs& args, FlatmmConfig::N_Warp_Tile, FlatmmConfig::K_Warp_Tile, MXPipelineProblem::TransposeC, - memory_operation, FlatmmConfig::NumWaveGroups, false, // FixedVectorSize 1, // VectorSizeC diff --git a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp index 9e2bc3e3fb..1c56295f9f 100644 --- a/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp +++ b/example/ck_tile/19_gemm_multi_d/gemm_multi_d_fp16.cpp @@ -81,87 +81,45 @@ auto gemm_multi_d(const gemm_multi_d_kargs& args, const ck_tile::stream_config& using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - // Epilogue selection: set to true for chainer-based, false for standard - // CShuffleEpilogue - constexpr bool UseChainerEpilogue = true; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = std::conditional_t< - UseChainerEpilogue, - // Chainer-based epilogue - ck_tile::EpilogueChainer, - ck_tile::DefaultScheduleTag>>, - // Standard CShuffleEpilogue - ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>>; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y - << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y - << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_gemm_multi_d_fp16_example.inc" diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp index d2663b033c..ca8573d6d2 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_data_invoker.hpp @@ -59,94 +59,80 @@ struct GroupedConvolutionBackwardDataInvoker ConvConfig::NumWaveGroups>; constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + InDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - InDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GroupedConvolutionBackwardDataKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - ck_tile::hip_check_error(hipMemsetAsync( - kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); - }; - - return ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + auto preprocess = [&]() { + ck_tile::hip_check_error(hipMemsetAsync( + kargs.in_ptr, 0, args.template GetInputByte(), s.stream_id_)); }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + return ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp index afe43cd1c0..90874e6018 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_invoker.hpp @@ -59,104 +59,85 @@ struct GroupedConvolutionBackwardWeightInvoker ConvConfig::NumWaveGroups>; constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + 
using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - const auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } + + auto preprocess = [&]() { + if(args.k_batch > 1) { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + ck_tile::hip_check_error(hipMemsetAsync( + kargs.wei_ptr, 0, args.template GetWeightByte(), s.stream_id_)); } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - if(kargs.k_batch > 1) - { - ck_tile::hip_check_error( - hipMemsetAsync(kargs.wei_ptr, - 0, - args.template GetWeightByte(), - s.stream_id_)); - } - }; - - const auto ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - const auto split_k = kargs.k_batch; - - return InvokerResult{ave_time, split_k}; }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return InvokerResult{ave_time, args.k_batch}; } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp index ad5e8ae70f..c4d618a0bf 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_backward_weight_two_stage_invoker.hpp @@ -65,163 +65,143 @@ struct GroupedConvolutionBackwardWeightTwoStageInvoker constexpr auto scheduler = ConvConfig::Scheduler; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = 
memory_operation_.value; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + OutDataType, + InDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + WeiDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - OutDataType, - InDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - WeiDataType, - GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; - using Kernel = ck_tile::GroupedConvolutionBackwardWeightKernel; + const ck_tile::index_t spatial_lengths_accum = + std::accumulate(args.filter_spatial_lengths_.begin(), + args.filter_spatial_lengths_.end(), + 1, + std::multiplies()); + ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * + sizeof(WorkspaceDataType)); + ck_tile::GroupedConvBwdWeightHostArgs ws_args = ck_tile::GroupedConvBwdWeightHostArgs(args); + auto c_ptr = ws_args.wei_ptr; + ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - const ck_tile::index_t spatial_lengths_accum = - std::accumulate(args.filter_spatial_lengths_.begin(), - args.filter_spatial_lengths_.end(), - 1, - std::multiplies()); - ck_tile::DeviceMem ws_m_n_dev_buf(args.G_ * args.K_ * args.C_ * spatial_lengths_accum * - sizeof(WorkspaceDataType)); - ck_tile::GroupedConvBwdWeightHostArgs ws_args = - ck_tile::GroupedConvBwdWeightHostArgs(args); - auto c_ptr = ws_args.wei_ptr; - ws_args.wei_ptr = ws_m_n_dev_buf.GetDeviceBuffer(); - const auto kargs = Kernel::MakeKernelArgs(ws_args); + const auto kargs = Kernel::MakeKernelArgs(ws_args); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); + } - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); - } + using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; + using BlockTile = ck_tile::sequence<2048>; + using BlockWarps = ck_tile::sequence<8>; + using WarpTile = ck_tile::sequence<64>; - using XElementwiseOperation = ck_tile::element_wise::UnaryConvert; - using BlockTile = ck_tile::sequence<2048>; - using BlockWarps = ck_tile::sequence<8>; - using WarpTile = ck_tile::sequence<64>; + using ElementwiseShape = + ck_tile::ElementWiseShape; + using Problem = ck_tile::ElementWisePipelineProblem; + using ElementwiseKernel = + ck_tile::ElementWiseKernel; - using ElementwiseShape = - ck_tile::ElementWiseShape; - using Problem = ck_tile::ElementWisePipelineProblem; - using ElementwiseKernel = - ck_tile::ElementWiseKernel; + ck_tile::index_t total_elements = 1; + std::vector shape = { + static_cast(args.G_ * args.K_), + static_cast(args.C_ * spatial_lengths_accum)}; - ck_tile::index_t total_elements = 1; - std::vector shape = { - static_cast(args.G_ * args.K_), - static_cast(args.C_ * spatial_lengths_accum)}; + for(auto d : shape) + total_elements *= d; - for(auto d : shape) - total_elements *= d; + const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); - const ck_tile::index_t kBlockSize = ElementwiseKernel::BlockSize(); + constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); + ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block; - constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{}); - ck_tile::index_t kGridSize = - (total_elements + elements_per_block - 1) / elements_per_block; + auto input_tensors = ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); + auto input_size = ck_tile::make_tuple(shape[0], shape[1]); - auto input_tensors = - ck_tile::make_tuple(static_cast(ws_args.wei_ptr)); - auto input_size = ck_tile::make_tuple(shape[0], shape[1]); + // Check if the kernel configuration is supported + if(!ElementwiseKernel::IsSupportedArgument(input_size)) + { + throw std::runtime_error( + "Wrong! Elementwise arguments not supported! Skipping gemm!\n"); + } - // Check if the kernel configuration is supported - if(!ElementwiseKernel::IsSupportedArgument(input_size)) - { - throw std::runtime_error( - "Wrong! Elementwise arguments not supported! 
Skipping gemm!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - auto preprocess = [&]() { - if(kargs.k_batch > 1) - ck_tile::hip_check_error( - hipMemsetAsync(ws_args.wei_ptr, - 0, - shape[0] * shape[1] * sizeof(WorkspaceDataType), - s.stream_id_)); - }; - - const auto ave_time = ck_tile::launch_kernel_time_mask( - s, - preprocess, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), - ck_tile::make_kernel( - ElementwiseKernel{}, - kGridSize, - kBlockSize, - 0, - input_size, - ck_tile::make_tuple(shape[1], 1), // Input Stride - ck_tile::make_tuple(shape[1], 1), // Output Stride - input_tensors, - static_cast(c_ptr))); - - const auto split_k = kargs.k_batch; - - return InvokerResult{ave_time, split_k}; + auto preprocess = [&]() { + if(args.k_batch > 1) + ck_tile::hip_check_error( + hipMemsetAsync(ws_args.wei_ptr, + 0, + shape[0] * shape[1] * sizeof(WorkspaceDataType), + s.stream_id_)); }; - if(args.k_batch == 1) - { - return Run(MemoryOpSet{}); - } - else - { - return Run(MemoryOpAtomicAdd{}); - } + float ave_time = ck_tile::launch_kernel_time_mask( + s, + preprocess, + ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs), + ck_tile::make_kernel( + ElementwiseKernel{}, + kGridSize, + kBlockSize, + 0, + input_size, + ck_tile::make_tuple(shape[1], 1), // Input Stride + ck_tile::make_tuple(shape[1], 1), // Output Stride + input_tensors, + static_cast(c_ptr))); + return InvokerResult{ave_time, kargs.k_batch}; } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp index 82541bb593..c94466aeb2 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_invoker.hpp @@ -70,91 +70,74 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Regular Convolution: Simple, no split-image // ===================================================================== - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< - InDataType, - WeiDataType, - AccDataType, - GemmShape, - GemmUniversalTraits, - scheduler, - ck_tile::element_wise::PassThrough, - ck_tile::element_wise::PassThrough, - OutDataType, - 
GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, - GroupedConvTraitsType::VectorSizeA, - GroupedConvTraitsType::VectorSizeB>; + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem< + InDataType, + WeiDataType, + AccDataType, + GemmShape, + GemmUniversalTraits, + scheduler, + ck_tile::element_wise::PassThrough, + ck_tile::element_wise::PassThrough, + OutDataType, + GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, + GroupedConvTraitsType::VectorSizeA, + GroupedConvTraitsType::VectorSizeB>; - using GemmPipeline = typename PipelineTypeTraits< - ConvConfig::Pipeline>::template GemmPipeline; + using GemmPipeline = typename PipelineTypeTraits< + ConvConfig::Pipeline>::template GemmPipeline; - using ConvEpilogue = ck_tile::CShuffleEpilogue>; + using ConvEpilogue = ck_tile::CShuffleEpilogue>; - using Kernel = ck_tile::GroupedConvolutionForwardKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using Kernel = ck_tile::GroupedConvolutionForwardKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping conv!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << UniversalGemmProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << '\n' - << "Vector size A: " << GemmPipeline::GetVectorSizeA() - << ", Vector size B: " << GemmPipeline::GetVectorSizeB() - << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - // ===================================================================== - // Split-K dispatch - // ===================================================================== - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(MemoryOpSet{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping conv!\n"); } - else + + if(s.log_level_ > 0) { - return Run(MemoryOpAtomicAdd{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << UniversalGemmProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << '\n' + << "Vector size A: " << GemmPipeline::GetVectorSizeA() + << ", Vector size B: " << GemmPipeline::GetVectorSizeB() + << ", Vector size C: " << ConvEpilogue::GetVectorSizeC() << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp index 4261385a84..5dec340668 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_forward_large_tensor_invoker.hpp @@ -213,8 +213,7 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== // Kernel launch lambda: Uses EnableSplitImage based on layout support // ===================================================================== - const auto Run = [&](const auto memory_operation_, const auto enable_split_image_) { - constexpr auto memory_operation = memory_operation_.value; + const auto Run = [&](const auto enable_split_image_) { constexpr bool EnableSplitImage = enable_split_image_.value; using GroupedConvTraitsType = std::conditional_t>; @@ -332,17 +330,11 @@ struct GroupedConvolutionForwardInvoker // ===================================================================== if(use_split_image) { - if(args.k_batch == 1) - return Run(MemoryOpSet{}, ck_tile::bool_constant{}); - else - return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); + return Run(ck_tile::bool_constant{}); } else { - if(args.k_batch == 1) - return Run(MemoryOpSet{}, ck_tile::bool_constant{}); - else - return Run(MemoryOpAtomicAdd{}, ck_tile::bool_constant{}); + return Run(ck_tile::bool_constant{}); } } }; diff --git a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp index 63dd54dcae..a78a880815 100644 --- a/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp +++ b/example/ck_tile/20_grouped_convolution/grouped_convolution_utils.hpp @@ -13,11 +13,6 @@ #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "conv_configs.hpp" -using MemoryOpSet = - std::integral_constant; -using MemoryOpAtomicAdd = std::integral_constant; - template auto calculate_rtol_atol(const ck_tile::index_t GemmK, const ck_tile::index_t kbatch, diff --git a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp index acb9126d65..9202bf9d98 100644 --- a/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp +++ b/example/ck_tile/22_gemm_multi_abd/gemm_multi_abd_fp16.cpp @@ -85,60 +85,44 @@ auto gemm_multi_abd(const gemm_multi_abd_kargs& args, const ck_tile::stream_conf using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< UniversalGemmProblem>; - const auto Run = [&](const auto memory_operation_) { - 
constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y - << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y - << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - return Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - return Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " + << blocks.z << "}" << std::endl; } + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } #include "run_gemm_multi_abd_fp16_example.inc" diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 47a22cdcba..d8988be7b0 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -173,77 +173,30 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str printf( "TiledPermuteN: %d (QuantGroupSize::kN=%d)\n", TiledPermuteN, BQuantGroupSize::kN); } - - // Epilogue selection: use chainer for RowCol/Tensor quant, standard for others - // Toggle to switch between chainer-based and standard CShuffleEpilogue - constexpr bool UseChainerEpilogue = true; - - // Define the schedule tag based on quant mode - using ScheduleTag = - std::conditional_t>; - - using GemmEpilogue = std::conditional_t< - UseChainerEpilogue && (QuantMode == ck_tile::QuantType::RowColQuant || - QuantMode == ck_tile::QuantType::TensorQuant), - // Chainer-based epilogue for RowCol/Tensor quant modes - ck_tile::EpilogueChainer, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>, - ScheduleTag>>, - // Standard CShuffleEpilogue for other modes - ck_tile::CShuffleEpilogue, typename TypeConfig::ADataType, - std::conditional_t< - std::is_same_v, - typename TypeConfig::ADataType, - typename TypeConfig::BDataType>, - ck_tile::tuple<>, - 
typename TypeConfig::AccDataType, - typename TypeConfig::CDataType, - ck_tile::tuple<>, - CLayout, - CDEElementWise, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - GemmConfig::M_Warp, - GemmConfig::N_Warp, - GemmConfig::M_Warp_Tile, - GemmConfig::N_Warp_Tile, - GemmConfig::K_Warp_Tile, - transpose_c, - ck_tile::memory_operation_enum::set, - 1, - false, - 1, - TiledPermuteN>>>; - + typename TypeConfig::BDataType>, + ck_tile::tuple<>, + typename TypeConfig::AccDataType, + typename TypeConfig::CDataType, + ck_tile::tuple<>, + CLayout, + CDEElementWise, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + GemmConfig::M_Warp, + GemmConfig::N_Warp, + GemmConfig::M_Warp_Tile, + GemmConfig::N_Warp_Tile, + GemmConfig::K_Warp_Tile, + transpose_c, + 1, + false, + 1, + TiledPermuteN>>; using Kernel = ck_tile::QuantGemmKernel; diff --git a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp index d3ee9fe9c6..828c861349 100644 --- a/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp +++ b/example/ck_tile/40_streamk_gemm/streamk_gemm_basic.cpp @@ -48,112 +48,87 @@ std::tuple gemm(const ck_tile::StreamKHostArgs& args, GemmConfiguration::NUM_WAVE_GROUPS, GemmConfiguration::PRESHUFFLE>; - const auto runKernel = [&](const auto memory_operation) -> std::tuple { - // We create the GEMM pipeline without specifying has_hot_loop or tail_num. - // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K - // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K - // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = - ck_tile::UniversalGemmPipelineProblem; + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. 
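// For reference, the non-Stream-K examples in this patch resolve the loop
// structure on the host and dispatch once through TailHandler, roughly
// (sketch only; `Run` is the launch functor, as in the flatmm examples above):
//
//   const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
//   const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
//   const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
//   BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
//
// Stream-K cannot fix these at launch time, so they are evaluated per
// workgroup inside the kernel's RunGemm instead.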
+    using UniversalGemmProblem =
+        ck_tile::UniversalGemmPipelineProblem;

-        using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3;
+    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3;

-        using GemmEpilogue = ck_tile::CShuffleEpilogue<
-            ck_tile::CShuffleEpilogueProblem>;
+    using GemmEpilogue = ck_tile::CShuffleEpilogue<
+        ck_tile::CShuffleEpilogueProblem>;

-        using Kernel = ck_tile::StreamKKernel;
+    using Kernel = ck_tile::StreamKKernel;

-        auto kernel_args = Kernel::MakeKernelArgs(args);
-        const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
-        ck_tile::DeviceMem workspace_data(workspace_size);
-        workspace_data.SetZero();
-        kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
-
-        dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
-        dim3 blocks = Kernel::BlockSize();
-
-        if(!Kernel::IsSupportedArgument(kernel_args))
-        {
-            throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
-        }
-
-        if(stream_config.log_level_ > 0)
-        {
-            std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
-                      << "shape: " << GemmShape::GetName() << '\n'
-                      << "problem: " << UniversalGemmProblem::GetName() << '\n'
-                      << "pipeline: " << GemmPipeline::GetName() << '\n'
-                      << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
-                      << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
-                      << std::endl;
-        }
-
-        auto reset_data_buffers = [&]() {
-            if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
-            {
-                // Clear the output C tensor results after each repetition of the kernel
-                hipGetErrorString(hipMemsetAsync(
-                    args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
-            }
-            else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
-            {
-                // Reset sk flags to zero before each repetition of the kernel
-                workspace_data.SetZero();
-            }
-        };
-
-        std::function preprocess = reset_data_buffers;
-
-        float average_time =
-            ck_tile::launch_kernel_time_mask(stream_config,
-                                             preprocess,
-                                             ck_tile::make_kernel(
-                                                 Kernel{}, grids, blocks, 0, kernel_args));
-
-        ck_tile::index_t num_wgs_per_tile =
-            kernel_args.tile_partitioner.estimate_num_wgs_per_tile();
-        return std::tuple{average_time, num_wgs_per_tile};
-    };
+    auto kernel_args = Kernel::MakeKernelArgs(args);
+    const auto workspace_size = Kernel::GetWorkSpaceSize(kernel_args);
+    ck_tile::DeviceMem workspace_data(workspace_size);
+    workspace_data.SetZero();
+    kernel_args.workspace_ptr = workspace_data.GetDeviceBuffer();
+
+    dim3 grids = Kernel::GridSize(kernel_args.tile_partitioner);
+    dim3 blocks = Kernel::BlockSize();
+
+    if(!Kernel::IsSupportedArgument(kernel_args))
+    {
+        throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
+    }
+
+    if(stream_config.log_level_ > 0)
+    {
+        std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
+                  << "shape: " << GemmShape::GetName() << '\n'
+                  << "problem: " << UniversalGemmProblem::GetName() << '\n'
+                  << "pipeline: " << GemmPipeline::GetName() << '\n'
+                  << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
+                  << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
+                  << std::endl;
+    }
+
+    auto reset_data_buffers = [&]() {
+        if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
+        {
+            // Clear the output C tensor results after each repetition of the kernel
+            hipGetErrorString(hipMemsetAsync(
+                args.e_ptr, 0, args.M * args.N * sizeof(CDataType), stream_config.stream_id_));
+        }
+        else if constexpr(ReductionStrategy == ck_tile::StreamKReductionStrategy::Reduction)
+        {
+            // Reset sk flags to zero before each repetition of the kernel
+            workspace_data.SetZero();
+        }
+    };

-    if constexpr(ck_tile::StreamKReductionStrategy::Atomic == ReductionStrategy)
-
{ - return runKernel(ck_tile::integral_constant{}); - } - else // We are using ck_tile::StreamKReductionStrategy::Reduction - { - return runKernel(ck_tile::integral_constant{}); - } + std::function preprocess = reset_data_buffers; + + float average_time = + ck_tile::launch_kernel_time_mask(stream_config, + preprocess, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kernel_args)); + + ck_tile::index_t num_wgs_per_tile = kernel_args.tile_partitioner.estimate_num_wgs_per_tile(); + return std::tuple{average_time, num_wgs_per_tile}; } #include "run_gemm_example.inc" diff --git a/example/ck_tile/41_batched_contraction/batched_contraction.cpp b/example/ck_tile/41_batched_contraction/batched_contraction.cpp index f9f13c6e85..1e159a5615 100644 --- a/example/ck_tile/41_batched_contraction/batched_contraction.cpp +++ b/example/ck_tile/41_batched_contraction/batched_contraction.cpp @@ -92,67 +92,59 @@ float batched_contraction_impl(const ck_tile::BatchedContractionHostArgs; - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + using GemmPipeline = GEMM_PIPELINE; - using GemmPipeline = GEMM_PIPELINE; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = + ck_tile::BatchedContractionKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = - ck_tile::BatchedContractionKernel; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(kargs); + const dim3 blocks = Kernel::GetBlockSize(); - const dim3 grids = Kernel::GridSize(kargs); - const dim3 blocks = Kernel::GetBlockSize(); + if(!Kernel::IsSupportedArguments(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping contraction!\n"); + } - if(!Kernel::IsSupportedArguments(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping contraction!\n"); - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetKernelName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "problem: " << GemmPipelineProblem::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } + auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); - auto kernel = ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs); - - return ck_tile::launch_kernel(s, kernel); - }; - - return Run(); + return ck_tile::launch_kernel(s, kernel); } #define HANDLE_CASE(G, M, N, K) \ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp index cce95cb3f1..6ce508b47d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_tile_factory.hpp @@ -116,7 +116,6 @@ struct ConvTileFactory BLOCK_GEMM.warp_tile.k, GroupedConvTraitsType::FixedGemmParams::TransposeC, // TODO:: This template parameter will be moved inside the kernel - ck_tile::memory_operation_enum::set, BLOCK_GEMM.num_wave_groups, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, SCALAR_PER_VECTOR.c>>; diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp index ad31fc52bc..91c75e3e8d 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp @@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp index 47908e0e5b..e2e165967a 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp @@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", "CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp index 083d9d9955..5ec73d780f 100644 --- a/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp +++ b/experimental/builder/test/conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp @@ -39,7 +39,6 @@ TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP1 "Default", "Intrawave", 
"CShuffleEpilogue", - "set", "pipeline_AgBgCrCompV3", "DoubleSmemBuffer_0", "NumWaveGroups_1", diff --git a/experimental/builder/test/test_bwd_data_instance_traits.cpp b/experimental/builder/test/test_bwd_data_instance_traits.cpp index f26b5d7caf..fe94d16a7d 100644 --- a/experimental/builder/test/test_bwd_data_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_data_instance_traits.cpp @@ -81,7 +81,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/experimental/builder/test/test_bwd_weight_instance_traits.cpp b/experimental/builder/test/test_bwd_weight_instance_traits.cpp index c7c4e370e2..dbb3a0a8fc 100644 --- a/experimental/builder/test/test_bwd_weight_instance_traits.cpp +++ b/experimental/builder/test/test_bwd_weight_instance_traits.cpp @@ -184,7 +184,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/experimental/builder/test/test_fwd_instance_traits.cpp b/experimental/builder/test/test_fwd_instance_traits.cpp index 6dd2a4eada..ad0a2cadc6 100644 --- a/experimental/builder/test/test_fwd_instance_traits.cpp +++ b/experimental/builder/test/test_fwd_instance_traits.cpp @@ -795,7 +795,6 @@ TEST(InstanceTraits, TileInstanceStringReturnsCorrectFormat) 16 /*N_Warp_Tile*/, 16 /*K_Warp_Tile*/, GroupedConvTraitsType::FixedGemmParams::TransposeC, - ck_tile::memory_operation_enum::set /*memory_operation*/, 1 /*kNumWaveGroups*/, GroupedConvTraitsType::FixedGemmParams::FixedVectorSize, GroupedConvTraitsType::VectorSizeC>>; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 53bfa6041d..c73897f064 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -30,7 +30,6 @@ template struct CShuffleEpilogueProblem { - using AsDataType = remove_cvref_t; - using BsDataType = remove_cvref_t; - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - using DsDataType = remove_cvref_t; - using DsLayout = remove_cvref_t; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); - static constexpr index_t kMPerBlock = kM_; - static constexpr index_t kNPerBlock = kN_; - static constexpr index_t MWave = MWave_; - static constexpr index_t NWave = NWave_; - static constexpr index_t MPerXdl = MPerXdl_; - static constexpr index_t NPerXdl = NPerXdl_; - static constexpr index_t KPerXdl = KPerXdl_; - static constexpr index_t isCTransposed = isCTransposed_; - static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; - static constexpr bool FixedVectorSize = FixedVectorSize_; - static constexpr index_t VectorSizeC = VectorSizeC_; - static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; - static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; - static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; - static constexpr index_t kNumWaveGroups = 
kNumWaveGroups_; - static constexpr index_t NumDTensor = DsDataType::size(); + using AsDataType = remove_cvref_t; + using BsDataType = remove_cvref_t; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + using DsDataType = remove_cvref_t; + using DsLayout = remove_cvref_t; + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size(); + static constexpr index_t kMPerBlock = kM_; + static constexpr index_t kNPerBlock = kN_; + static constexpr index_t MWave = MWave_; + static constexpr index_t NWave = NWave_; + static constexpr index_t MPerXdl = MPerXdl_; + static constexpr index_t NPerXdl = NPerXdl_; + static constexpr index_t KPerXdl = KPerXdl_; + static constexpr index_t isCTransposed = isCTransposed_; + static constexpr bool FixedVectorSize = FixedVectorSize_; + static constexpr index_t VectorSizeC = VectorSizeC_; + static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_; + static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_; + static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_; + static constexpr index_t kNumWaveGroups = kNumWaveGroups_; + static constexpr index_t NumDTensor = DsDataType::size(); static_assert(NumDTensor == DsLayout::size(), "The size of DsDataType and DsLayout should be the same"); @@ -105,28 +103,27 @@ struct CShuffleEpilogue ADataType, BDataType>; - using ELayout = remove_cvref_t; - using CDElementwise = remove_cvref_t; - static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kMPerBlock = Problem::kMPerBlock; - static constexpr index_t kNPerBlock = Problem::kNPerBlock; - static constexpr index_t MWave = Problem::MWave; - static constexpr index_t NWave = Problem::NWave; - static constexpr index_t MPerXdl = Problem::MPerXdl; - static constexpr index_t NPerXdl = Problem::NPerXdl; - static constexpr index_t KPerXdl = Problem::KPerXdl; - static constexpr index_t isCTransposed = Problem::isCTransposed; - static constexpr bool FixedVectorSize = Problem::FixedVectorSize; - static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; - static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp; - static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; - static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr index_t MPerIteration = MPerXdl * MWave; - static constexpr index_t NPerIteration = NPerXdl * NWave; - static constexpr index_t NumDTensor = Problem::NumDTensor; - static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); - static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); + using ELayout = remove_cvref_t; + using CDElementwise = remove_cvref_t; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kMPerBlock = Problem::kMPerBlock; + static constexpr index_t kNPerBlock = Problem::kNPerBlock; + static constexpr index_t MWave = Problem::MWave; + static constexpr index_t NWave = Problem::NWave; + static constexpr index_t MPerXdl = Problem::MPerXdl; + static constexpr index_t NPerXdl = Problem::NPerXdl; + static constexpr index_t KPerXdl = Problem::KPerXdl; + static constexpr index_t isCTransposed = Problem::isCTransposed; + static constexpr bool FixedVectorSize = Problem::FixedVectorSize; + static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN; + static constexpr index_t BlockedXDLN_PerWarp = 
Problem::BlockedXDLN_PerWarp; + static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer; + static constexpr index_t VectorSizeC = Problem::VectorSizeC; + static constexpr index_t MPerIteration = MPerXdl * MWave; + static constexpr index_t NPerIteration = NPerXdl * NWave; + static constexpr index_t NumDTensor = Problem::NumDTensor; + static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave); + static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave); CDElementwise elfunc_; @@ -142,8 +139,7 @@ struct CShuffleEpilogue concat('x', MWave, NWave), concat('x', MPerXdl, NPerXdl, KPerXdl), VectorSizeC, - isCTransposed ? "CTransposed" : "CNotTransposed", - mem_op_string()); + isCTransposed ? "CTransposed" : "CNotTransposed"); // clang-format on } @@ -445,7 +441,8 @@ struct CShuffleEpilogue CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window, const COutTensor& c_out_tensor) { - if constexpr(MemoryOperation == memory_operation_enum::set) + if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } @@ -617,7 +614,8 @@ struct CShuffleEpilogue }); // store/update - if constexpr(MemoryOperation == memory_operation_enum::set) + if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::set) { store_tile(out_dram_window, c_out_tensor); } diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index cc2303582e..aafe7b9f58 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -15,17 +15,15 @@ template + bool UseRawStore_ = true> struct Default2DEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool UseRawStore = UseRawStore_; - static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; - static constexpr index_t NumDTensor = 0; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; + static constexpr index_t NumDTensor = 0; }; template -struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem + bool UseRawStore_ = true> +struct DefaultGemm2DEpilogueProblem + : public Default2DEpilogueProblem { using AsDataType = remove_cvref_t; using BsDataType = remove_cvref_t; @@ -81,7 +74,6 @@ struct Default2DEpilogue static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; - static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } @@ -102,7 +94,10 @@ struct Default2DEpilogue // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - if constexpr(MemoryOperation == memory_operation_enum::set) + // FIXME? + // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp == + // memory_operation_enum::set) + if constexpr(true) { if constexpr(is_partition_index) { @@ -123,7 +118,10 @@ struct Default2DEpilogue } else { - if constexpr(MemoryOperation == memory_operation_enum::set) + // FIXME? 
+ // if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp == + // memory_operation_enum::set) + if constexpr(true) { if constexpr(is_partition_index) { diff --git a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp index 9a33801c8f..42dab68e91 100644 --- a/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp @@ -558,21 +558,19 @@ struct FlatmmKernel return DTesnorIsValid; } - template - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); @@ -581,25 +579,81 @@ struct FlatmmKernel { return make_naive_tensor_view( a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } }(); - index_t kFlatK = - FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2)); - index_t kFlatN = kargs.N * kargs.K / kFlatK; - const auto& b_flat_tensor_view = [&]() { - return make_naive_tensor_view( - b_flat_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } }(); + // Step 3: Create tile window + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {block_idx_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, block_idx_m}); + } + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view + index_t kFlatK = + FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2)); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + + const auto& b_flat_tensor_view = make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + + // Step 2: No padding needed for b_flat + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -625,7 +679,56 @@ struct FlatmmKernel }, number{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + 
using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + }, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -647,98 +750,8 @@ struct FlatmmKernel } }(); - constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; - constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; - - constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; - constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; - - auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale - : 1; // per-token scale - auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale - : 1; // per-channel scale - - static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1, - "only support per-tensor or per-row scaling"); - static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1, - "only support per-tensor or per-column scaling"); - - const auto scale_m_view = make_naive_tensor_view( - kargs.scale_m_ptr.ptr, - make_tuple(kargs.M / ScaleGranularityM, - ScaleGranularityKA == 0 - ? 1 - : splitk_batch_offset.splitted_k / - (ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)), - make_tuple(scale_stride_m, 0), - number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, - number<1>{}); - const auto scale_n_view = make_naive_tensor_view( - kargs.scale_n_ptr.ptr, - make_tuple(ScaleGranularityKB == 0 - ? 1 - : (splitk_batch_offset.splitted_k / - (ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)), - kargs.N / ScaleGranularityN), - make_tuple(0, scale_stride_n), - number < ScaleGranularityN == 1 ? 
FlatmmPipeline::GetVectorSizeB() : 1 > {}, - number<1>{}); - - return make_tuple(a_tensor_view, - b_flat_tensor_view, - ds_tensor_view, - e_tensor_view, - scale_m_view, - scale_n_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -755,93 +768,72 @@ struct FlatmmKernel } }(); - return make_tuple(a_pad_view, - b_flat_tensor_view, - ds_pad_view, - e_pad_view, - views.at(number<4>{}), - views.at(number<5>{})); - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } - constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK; - constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK; + template + CK_TILE_DEVICE static auto MakeScaleMWindow(const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_m) + { + constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN; + constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK; - auto scale_m_window = make_tile_window(views.at(number<4>{}), - make_tuple(number{}, - number < ScaleGranularityKA == 0 - ? 
TilePartitioner::NPerBlock - : TilePartitioner::KPerBlock > {}), - {i_m, 0}); - auto scale_n_window = make_tile_window(views.at(number<5>{}), - make_tuple(number < ScaleGranularityKB == 0 - ? TilePartitioner::MPerBlock - : TilePartitioner::KPerBlock > {}, - number{}), - {0, i_n}); + auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale + : 1; // per-token scale - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - scale_m_window, - scale_n_window); + // Step 1: Create tensor view + const auto scale_m_view = make_naive_tensor_view( + kargs.scale_m_ptr.ptr, + make_tuple(kargs.M / ScaleGranularityM, + ScaleGranularityKA == 0 + ? 1 + : (splitk_batch_offset.splitted_k / ScaleGranularityKA)), + make_tuple(scale_stride_m, 0), + number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window(scale_m_view, + make_tuple(number{}, + number < ScaleGranularityKA == 0 + ? TilePartitioner::NPerBlock + : TilePartitioner::KPerBlock > {}), + {block_idx_m, 0}); + } + + template + CK_TILE_DEVICE static auto MakeScaleNWindow(const KernelArgs& kargs, + const SplitKBatchOffset& splitk_batch_offset, + const index_t block_idx_n) + { + constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN; + constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK; + + auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale + : 1; // per-channel scale + + // Step 1: Create tensor view + const auto scale_n_view = make_naive_tensor_view( + kargs.scale_n_ptr.ptr, + make_tuple( + ScaleGranularityKB == 0 ? 1 : (splitk_batch_offset.splitted_k / ScaleGranularityKB), + kargs.N / ScaleGranularityN), + make_tuple(0, scale_stride_n), + number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window(scale_n_view, + make_tuple(number < ScaleGranularityKB == 0 + ? TilePartitioner::MPerBlock + : TilePartitioner::KPerBlock > {}, + number{}), + {0, block_idx_n}); } template @@ -857,45 +849,74 @@ struct FlatmmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_m_window = MakeScaleMWindow(kargs, splitk_batch_offset, block_idx_m); + const auto& scale_n_window = MakeScaleNWindow(kargs, splitk_batch_offset, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); // Run GEMM cooperatively by whole workgroup. 
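[Editor's note] The hunks below drop the epilogue's MemoryOperation template constant: RunFlatmm now branches on kargs.k_batch at runtime, builds the E window with memory_operation_enum::set or memory_operation_enum::atomic_add baked into the view type, and the epilogue recovers that choice through the view's DstInMemOp member. A minimal standalone sketch of this runtime-to-compile-time dispatch, with toy View/Window types standing in for the ck_tile ones (names ToyView/ToyWindow are illustrative, not library API):

#include <cstdio>

enum class memory_operation_enum { set, atomic_add };

// Toy bottom tensor view: the destination memory operation is part of the
// type, mirroring the DstInMemOp member the epilogue hunks above query.
template <memory_operation_enum Op>
struct ToyView
{
    static constexpr memory_operation_enum DstInMemOp = Op;
};

template <typename BottomView>
struct ToyWindow
{
    // Returns by value so decltype(...) below names the view type directly.
    BottomView get_bottom_tensor_view() const { return BottomView{}; }
};

// The epilogue takes no memory-operation parameter any more; it reads the
// operation back from the window type with if constexpr.
template <typename OutWindow>
void run_epilogue(const OutWindow& out_window)
{
    if constexpr(decltype(out_window.get_bottom_tensor_view())::DstInMemOp ==
                 memory_operation_enum::set)
        std::puts("store_tile: plain store, k_batch == 1");
    else
        std::puts("update_tile: atomic_add accumulation, k_batch > 1");
}

int main()
{
    const int k_batch = 2; // runtime kernel argument
    if(k_batch == 1)
        run_epilogue(ToyWindow<ToyView<memory_operation_enum::set>>{});
    else
        run_epilogue(ToyWindow<ToyView<memory_operation_enum::atomic_add>>{});
    return 0;
}

Both branches are instantiated at compile time; the runtime k_batch test only selects which one executes, which is why the kernels below duplicate the epilogue call once per branch.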
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = FlatmmPipeline{}.template operator()( + const auto& c_block_tile = FlatmmPipeline{}.template operator()( a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong); - auto scale_m_window = gemm_tile_windows.at(number<4>{}); - auto scale_n_window = gemm_tile_windows.at(number<5>{}); - - // Run Epilogue Pipeline + // Run Epilogue Pipeline with k_batch dispatching if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1) { - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template - operator()( - c_block_window, - c_block_tile, - d_block_window, - smem_ptr_ping, - scale_m_window, - scale_n_window); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + scale_m_window, + scale_n_window); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + scale_m_window, + scale_n_window); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()( + e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{} + .template operator()( + e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -924,8 +945,7 @@ struct FlatmmKernel __shared__ char smem_ptr_ping[GetSmemPingSize()]; __shared__ char smem_ptr_pong[GetSmemPongSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && + if constexpr(!(EpiloguePipeline::GetVectorSizeC() % 2 != 0 && is_any_of::value)) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); diff --git a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp index 05d50666a5..61001522b0 100644 --- a/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp @@ -100,21 +100,19 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { if constexpr(std::is_same_v) { return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); @@ 
-123,25 +121,80 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( a_ptr, - make_tuple(splitk_batch_offset.splitted_k, kargs.M), + make_tuple(k_size, kargs.M), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); } }(); + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {block_idx_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, block_idx_m}); + } + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view index_t kFlatK = kargs.K * BlockGemmShape::WarpTile::at(I1); index_t kFlatN = kargs.N * kargs.K / kFlatK; - const auto& b_flat_tensor_view = [&]() { - return make_naive_tensor_view( - b_flat_ptr, - make_tuple(kFlatN, kFlatK), - make_tuple(kFlatK, 1), - number{}, - number<1>{}); - }(); + const auto& b_flat_tensor_view = make_naive_tensor_view( + b_flat_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + // Step 2: No padding needed for b_flat + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -167,7 +220,56 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + }, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -189,70 +291,8 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( - reinterpret_cast(scale_n.ptr), - make_tuple(FlatScaleN, FlatScaleK), - make_tuple(FlatScaleK, 1), - number<8>{}, - number<1>{}); - - return make_tuple( - a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view, scale_b_flat_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const 
auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -269,77 +309,37 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - if constexpr(std::is_same_v) - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {0, i_m}); - } - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } - auto scale_block_window = - make_tile_window(views.at(I4), - make_tuple(number{}, - number{}), - {i_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); + template + CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs, + const index_t block_idx_n) + { + auto scale_n = kargs.scale_n_ptr; - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - scale_block_window); + // Step 1: Create tensor view + index_t FlatScaleK = + (kargs.K / decltype(scale_n)::GranularityK) * N_Pack * BlockGemmShape::WarpTile::at(I1); + index_t FlatScaleN = kargs.N / N_Pack / BlockGemmShape::WarpTile::at(I1); + + const auto scale_b_flat_view = make_naive_tensor_view( + reinterpret_cast(scale_n.ptr), + make_tuple(FlatScaleN, FlatScaleK), + make_tuple(FlatScaleK, 1), + number<8>{}, + number<1>{}); + + // Step 2: Create tile window + return make_tile_window( + scale_b_flat_view, + make_tuple(number{}, + number{}), + {block_idx_n / BlockGemmShape::WarpTile::at(I1) / N_Pack, 0}); } template @@ -355,21 +355,15 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = 
MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_block_window = MakeScaleBBlockWindow(kargs, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& scale_block_window = gemm_tile_windows.at(I4); - static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK || ScaleM::GranularityMN == -1 // or ScaleA is disable || ScaleN::GranularityMN == -1, // or ScaleB is disable @@ -378,6 +372,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -434,8 +453,7 @@ struct F16xMXF4FlatmmKernel : FlatmmKernel::value)) { constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1); diff --git a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp index b47ec4a829..604089b7c4 100644 --- a/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/moe_flatmm_kernel.hpp @@ -1476,7 +1476,8 @@ struct MoeFlatmmKernel c_scatter_valids[mIter]); if constexpr(!IsInputGemm || - EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add) + decltype(c_block_window.get_bottom_tensor_view())::DstInMemOp == + memory_operation_enum::atomic_add) c_scatter_tile_window.update(c_out_tensor); else c_scatter_tile_window.store(c_out_tensor); diff --git a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp index 799f8f26a9..a58d71c790 100644 --- a/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp +++ b/include/ck_tile/ops/flatmm/kernel/mx_flatmm_kernel.hpp @@ -113,32 +113,50 @@ struct MXFlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTensorViews(const ADataType* a_ptr, - const BDataType* b_flat_ptr, - const std::array& ds_ptr, - EDataType* 
e_ptr, - const KernelArgs& kargs, - const SplitKBatchOffset& splitk_batch_offset) + template + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t block_idx_m) { + // Step 1: Create tensor view const auto& a_tensor_view = [&]() { static_assert(std::is_same_v, "A tensor for mx must be RowMajor"); return make_naive_tensor_view( a_ptr, - make_tuple(kargs.M, splitk_batch_offset.splitted_k), + make_tuple(kargs.M, k_size), make_tuple(kargs.stride_A, 1), number{}, number<1>{}); }(); + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, 0}); + } + + template + CK_TILE_DEVICE static auto MakeBFlatBlockWindow(const BDataType* b_flat_ptr, + const KernelArgs& kargs, + const index_t block_idx_n) + { + // Step 1: Create tensor view with special flat layout constexpr index_t kKPerBlock = MXFlatmmPipeline::kKPerBlock; constexpr index_t kNWarpTile = BlockGemmShape::WarpTile::at(I1); constexpr index_t flatKPerBlock = kKPerBlock * kNWarpTile; const index_t kFlatKBlocks = kargs.K / kKPerBlock; const index_t kFlatN = kargs.N / kNWarpTile; - const auto& b_flat_tensor_view = [&]() { + + const auto& b_flat_tensor_view = [&]() { static_assert(flatKPerBlock % MXFlatmmPipeline::GetVectorSizeB() == 0, "wrong! vector size for B tensor"); auto&& naive_desc = make_naive_tensor_descriptor_packed( @@ -153,6 +171,22 @@ struct MXFlatmmKernel : FlatmmKernel(b_flat_ptr, desc); }(); + // Step 2: No padding for flat B + // Step 3: Create tile window + return make_tile_window( + b_flat_tensor_view, + make_tuple(number{}, + number{}), + {static_cast(block_idx_n / BlockGemmShape::WarpTile::at(I1)), 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; @@ -178,7 +212,56 @@ struct MXFlatmmKernel : FlatmmKernel{}); - // TODO: enable vector write for C in ColMajor + // Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows + return generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_m, block_idx_n}); + } + else + { + return make_tile_window(ds_pad_view[i], + make_tuple(number{}, + number{}), + {block_idx_n, block_idx_m}); + } + }, + number{}); + } + + template + CK_TILE_DEVICE static auto MakeEBlockWindow(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view const auto& e_tensor_view = [&]() { if constexpr(std::is_same_v) { @@ -200,92 +283,8 @@ struct MXFlatmmKernel : FlatmmKernel{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view( - reinterpret_cast(scale_a.ptr), 
scale_a_desc); - }(); - - // B scale tensor view - const auto& scale_b_tensor_view = [&]() { - const auto scale_b_navie_desc = make_naive_tensor_descriptor_packed( - make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); - const auto scale_b_desc = transform_tensor_descriptor( - scale_b_navie_desc, - make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), - make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<0>{}, sequence<1>{})); - - return make_tensor_view( - reinterpret_cast(scale_b.ptr), scale_b_desc); - }(); - - return make_tuple(a_tensor_view, - b_flat_tensor_view, - ds_tensor_view, - e_tensor_view, - scale_a_tensor_view, - scale_b_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - static_assert(std::is_same_v, - "A tensor for mx must be RowMajor"); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_flat_tensor_view = views.at(I1); - - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor + // Step 2: Create padded view const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); if constexpr(std::is_same_v) { return pad_tensor_view(e_tensor_view, @@ -302,79 +301,71 @@ struct MXFlatmmKernel : FlatmmKernel - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_flat_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - static_assert(std::is_same_v, - "A tensor for mx must be RowMajor"); - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - }(); - - const auto& b_flat_block_window = - make_tile_window(b_flat_pad_view, - make_tuple(number{}, - number{}), - {static_cast(i_n / BlockGemmShape::WarpTile::at(I1)), 0}); - - const auto ds_block_window = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, i_n}); - } - else - { - return make_tile_window(ds_pad_view[i], - make_tuple(number{}, - number{}), - {i_n, i_m}); - } - }, - number{}); - - auto e_block_window = make_tile_window( + // Step 3: Create tile window + return make_tile_window( e_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); + } + template + CK_TILE_DEVICE static auto MakeScaleABlockWindow(const KernelArgs& kargs, + const index_t block_idx_m) + { static constexpr int BlockScaleSize = 32; - auto scale_a_block_window = make_tile_window( - views.at(I4), + const auto&& scale_packs_m = integer_divide_ceil(kargs.M, (MXdlPack * MThreadPerXdl)); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl); + + // Step 1: Create tensor view + const auto scale_a_naive_desc = 
make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_m, scale_packs_k, KThreadPerXdl, MThreadPerXdl)); + const auto scale_a_desc = transform_tensor_descriptor( + scale_a_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_m, MThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto& scale_a_tensor_view = make_tensor_view( + reinterpret_cast(kargs.scale_m_ptr.ptr), scale_a_desc); + + // Step 2: Create tile window + return make_tile_window( + scale_a_tensor_view, make_tuple(number{}, number{}), - {i_m / MXdlPack, 0}); + {block_idx_m / MXdlPack, 0}); + } - auto scale_b_block_window = make_tile_window( - views.at(I5), + template + CK_TILE_DEVICE static auto MakeScaleBBlockWindow(const KernelArgs& kargs, + const index_t block_idx_n) + { + static constexpr int BlockScaleSize = 32; + + const auto&& scale_packs_n = integer_divide_ceil(kargs.N, (NXdlPack * NThreadPerXdl)); + const auto&& scale_packs_k = kargs.K / BlockScaleSize / (KXdlPack * KThreadPerXdl); + + // Step 1: Create tensor view + const auto scale_b_naive_desc = make_naive_tensor_descriptor_packed( + make_tuple(scale_packs_n, scale_packs_k, KThreadPerXdl, NThreadPerXdl)); + const auto scale_b_desc = transform_tensor_descriptor( + scale_b_naive_desc, + make_tuple(make_merge_transform(make_tuple(scale_packs_n, NThreadPerXdl)), + make_merge_transform(make_tuple(scale_packs_k, KThreadPerXdl))), + make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto& scale_b_tensor_view = make_tensor_view( + reinterpret_cast(kargs.scale_n_ptr.ptr), scale_b_desc); + + // Step 2: Create tile window + return make_tile_window( + scale_b_tensor_view, make_tuple(number{}, number{}), - {i_n / NXdlPack, 0}); - - return make_tuple(a_block_window, - b_flat_block_window, - ds_block_window, - e_block_window, - scale_a_block_window, - scale_b_block_window); + {block_idx_n / NXdlPack, 0}); } template @@ -390,22 +381,16 @@ struct MXFlatmmKernel : FlatmmKernel( - a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_flat_block_window = MakeBFlatBlockWindow(b_flat_ptr, kargs, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); + const auto& scale_a_block_window = MakeScaleABlockWindow(kargs, block_idx_m); + const auto& scale_b_block_window = MakeScaleBBlockWindow(kargs, block_idx_n); const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k); - // Run GEMM cooperatively by whole workgroup. 
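[Editor's note] Every new Make*BlockWindow helper in these files follows the same three-step recipe that the removed MakeGemmTensorViews / MakeGemmPadViews / MakeGemmTileWindows trio used to interleave across tuples: build a naive tensor view, pad it up to the block tile, then open a tile window at the workgroup's coordinates. A toy host-side sketch of that pipeline, using simplified 2-D structs rather than the ck_tile descriptors (all names here are illustrative):

#include <cstdio>

// Step 1: a naive view is just a pointer plus lengths and a row stride.
struct TensorView { const float* ptr; int rows, cols, stride; };

// Step 2: padding rounds the lengths up to block-tile multiples so that
// windows on the tensor edge stay in bounds.
struct PaddedView { TensorView base; int padded_rows, padded_cols; };

PaddedView pad_tensor_view(TensorView v, int tile_m, int tile_n)
{
    auto round_up = [](int x, int t) { return (x + t - 1) / t * t; };
    return {v, round_up(v.rows, tile_m), round_up(v.cols, tile_n)};
}

// Step 3: a window fixes the tile extents and origin for one workgroup.
struct TileWindow { PaddedView view; int origin_m, origin_n, tile_m, tile_n; };

TileWindow make_tile_window(PaddedView v, int tile_m, int tile_n,
                            int origin_m, int origin_n)
{
    return {v, origin_m, origin_n, tile_m, tile_n};
}

// One helper per operand, as in MakeABlockWindow / MakeDBlockWindows above;
// i_m is the block's row offset, matching the {block_idx_m, 0} origin.
TileWindow make_a_block_window(const float* a, int M, int K, int stride,
                               int tile_m, int tile_k, int i_m)
{
    TensorView view{a, M, K, stride};                          // Step 1
    PaddedView padded = pad_tensor_view(view, tile_m, tile_k); // Step 2
    return make_tile_window(padded, tile_m, tile_k, i_m, 0);   // Step 3
}

int main()
{
    float a[96 * 64] = {};
    TileWindow w = make_a_block_window(a, 96, 64, 64, 32, 64, /*i_m=*/64);
    std::printf("A window origin: (%d, %d), tile %dx%d\n",
                w.origin_m, w.origin_n, w.tile_m, w.tile_n);
}

Keeping one helper per operand places each window's layout logic (padding flags, origin ordering for row- vs. column-major) next to the operand it belongs to, instead of threading everything through one tuple.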
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_flat_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& scale_a_block_window = gemm_tile_windows.at(I4); - const auto& scale_b_block_window = gemm_tile_windows.at(I5); - static_assert(ScaleM::GranularityK == ScaleN::GranularityK // have the same granK || ScaleM::GranularityMN == -1 // or ScaleA is disable || ScaleN::GranularityMN == -1, // or ScaleB is disable @@ -422,22 +407,46 @@ struct MXFlatmmKernel : FlatmmKernel( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, + c_block_tile, + ds_block_window, + smem_ptr_ping, + kargs.scale_m_ptr + block_idx_m, + kargs.scale_n_ptr + block_idx_n); + } } else if(UseDefaultScheduler || (get_warp_id() == 0)) { - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); - EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_ping); + if(kargs.k_batch == 1) + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } + else + { + auto e_block_window = MakeEBlockWindow( + e_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(e_block_window, c_block_tile, ds_block_window, smem_ptr_ping); + } } } @@ -466,27 +475,17 @@ struct MXFlatmmKernel : FlatmmKernel::value)) - { - constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); - RunFlatmm(a_ptr, - b_flat_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_ping, - smem_ptr_pong, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - else - { - static_assert(false, - "Unimplemented: atomic_add with odd vector size for fp16/bf16"); - } + constexpr auto scheduler_type = (MXFlatmmPipeline::NumWaveGroups == 1); + RunFlatmm(a_ptr, + b_flat_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_ping, + smem_ptr_pong, + kargs, + splitk_batch_offset, + i_m, + i_n); partition_idx += gridDim.x; } while(UsePersistentKernel && partition_idx < total_work_tile_cnt); } diff --git a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp index 95114e8496..5ba5699dda 100644 --- a/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp @@ -361,6 +361,7 @@ struct GroupedGemmKernel * * @param a_ptr input A pointer * @param b_ptr input B pointer + * @param ds_ptr input Ds pointer * @param c_ptr output C pointer * @param smem_ptr_0 The start memory pointer of the shared memory block. 
* @param kargs GEMM kernel arguments @@ -381,49 +382,54 @@ struct GroupedGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m) + .at(Base::I0); + const auto& b_block_window = + Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n) + .at(Base::I0); + const auto& d_block_window = + Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM pipeline + // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + a_block_window, b_block_window, num_loop, smem_ptr_0); + // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } /** * @brief Runs single GEMM problem cooperatively by whole workgroup. * - * @note The GEMM pipeline is selected in-kernel based on the number of K-loops - * and the tail-number. This is needed for the persistent tile-loop when - * we didn't have access to the K dimension on the host. + * @note RunGEMM2LDS with two shared memory buffers using the ping pong buffer mechanism. * * @param a_ptr input A pointer * @param b_ptr input B pointer * @param c_ptr output C pointer - * @param smem_ptr_0 The start memory pointer of the shared memory block. - * @param smem_ptr_1 The second start memory pointer of the shared memory block. + * @param ds_ptr input Ds pointer + * @param smem_ptr_0 The starting pointer of 1st shared memory block. + * @param smem_ptr_1 The starting pointer of 2nd shared memory block. * @param kargs GEMM kernel arguments - * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k - * batch. + * @param splitk_batch_offset Utility structure used to calculate k batch. * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. 
* @@ -440,54 +446,39 @@ struct GroupedGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - {a_ptr}, {b_ptr}, ds_ptr, c_ptr, kargs, splitk_batch_offset.splitted_k); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindows({a_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_m) + .at(Base::I0); + const auto& b_block_window = + Base::MakeBBlockWindows({b_ptr}, kargs, splitk_batch_offset.splitted_k, block_idx_n) + .at(Base::I0); + const auto& d_block_window = + Base::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I1); - const auto& d_block_window = gemm_tile_windows.at(Base::I2); - - // Get hot-loop and tail configuration const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); - const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM pipeline with compile-time branching - const auto& c_block_tile = [&]() { - if constexpr(GemmPipeline::Preshuffle) - { - // Preshuffle version - without has_hot_loop parameter - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - else - { - // Regular version - with has_hot_loop parameter - const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); - return GemmPipeline{}.template operator()(a_block_window[Base::I0], - b_block_window[Base::I0], - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0, - smem_ptr_1); - } - }(); + // Run GEMM cooperatively by whole workgroup. 
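[Editor's note] The revised @note above describes RunGemm2LDS as a ping-pong scheme over two shared-memory buffers, which is why the pipeline call that follows takes both smem_ptr_0 and smem_ptr_1. A host-side toy of the mechanism, with plain arrays standing in for LDS and a running sum standing in for the block GEMM (assumed behavior for illustration, not the ck_tile pipeline):

#include <cstdio>

// Toy ping-pong pipeline: even iterations compute out of buffer 0 while
// buffer 1 is refilled, odd iterations swap roles, so the loads for
// iteration i+1 overlap the math on iteration i.
void run_gemm_2lds(const float* a, float* out, int num_loop)
{
    float smem_ping[4]; // stands in for smem_ptr_0
    float smem_pong[4]; // stands in for smem_ptr_1
    float* bufs[2] = {smem_ping, smem_pong};

    // Prologue: fill the first buffer with tile 0.
    for(int e = 0; e < 4; ++e)
        bufs[0][e] = a[e];

    float acc = 0.f;
    for(int i = 0; i < num_loop; ++i)
    {
        float* compute  = bufs[i % 2];       // holds iteration i's tile
        float* prefetch = bufs[(i + 1) % 2]; // being refilled for i+1

        if(i + 1 < num_loop) // issue the next tile's loads before the math
            for(int e = 0; e < 4; ++e)
                prefetch[e] = a[(i + 1) * 4 + e];

        for(int e = 0; e < 4; ++e) // "MMA" on the current tile
            acc += compute[e];
    }
    *out = acc;
}

int main()
{
    float a[16];
    for(int e = 0; e < 16; ++e) a[e] = 1.f;
    float out = 0.f;
    run_gemm_2lds(a, &out, /*num_loop=*/4);
    std::printf("acc = %f\n", out); // expect 16.0
}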
+ const auto& c_block_tile = GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I3); - EpiloguePipeline{}.template - operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr, diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp index d1fd32dc1b..47e59c4704 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_kernel.hpp @@ -222,19 +222,13 @@ struct StreamKKernel const index_t block_idx_n, const index_t k_size) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - UniversalGemmKernel::template MakeGemmTensorViews( - as_ptr, bs_ptr, ds_ptr, c_ptr, kargs, k_size); - - const auto& gemm_pad_views = UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - // Run GEMM cooperatively by whole workgroup. - const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); - const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); - const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + // Create block windows using specialized methods + const auto& as_block_window = + UniversalGemmKernel::MakeABlockWindows(as_ptr, kargs, k_size, block_idx_m); + const auto& bs_block_window = + UniversalGemmKernel::MakeBBlockWindows(bs_ptr, kargs, k_size, block_idx_n); + const auto& ds_block_window = + UniversalGemmKernel::MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Since num_loop can vary per WG and per iteration of the Stream-K while loop, we compute // has_hot_loop and tail_num here. This is a similar pattern used by grouped GEMM. In this @@ -243,6 +237,7 @@ struct StreamKKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); + // Run GEMM cooperatively by whole workgroup. 
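[Editor's note] Stream-K takes a different route than the runtime k_batch branch used by the other kernels: its tile partitioner (see the streamk_gemm_tile_partitioner.hpp hunk below) derives the memory operation from the reduction strategy at compile time, so the kernel can write MakeCBlockWindows<TilePartitioner::MemoryOperation>(...) with no runtime dispatch. A minimal sketch of that constexpr mapping, assuming a two-value strategy enum (ToyStreamKPartitioner is an illustrative name):

enum class StreamKReductionStrategy { Atomic, Reduction };
enum class memory_operation_enum { set, atomic_add };

// Mirror of the partitioner addition below: the reduction strategy fixes the
// epilogue memory operation at compile time. Atomic accumulation writes with
// atomic_add; the two-stage reduction path can use a plain set.
template <StreamKReductionStrategy ReductionStrategy>
struct ToyStreamKPartitioner
{
    static constexpr auto MemoryOperation =
        (ReductionStrategy == StreamKReductionStrategy::Atomic)
            ? memory_operation_enum::atomic_add
            : memory_operation_enum::set;
};

using AtomicPartitioner   = ToyStreamKPartitioner<StreamKReductionStrategy::Atomic>;
using TwoStagePartitioner = ToyStreamKPartitioner<StreamKReductionStrategy::Reduction>;

static_assert(AtomicPartitioner::MemoryOperation == memory_operation_enum::atomic_add);
static_assert(TwoStagePartitioner::MemoryOperation == memory_operation_enum::set);

int main() { return 0; } // mapping is verified entirely at compile time

Pushing the choice into the partitioner keeps the kernel body free of memory-operation template parameters, matching the removal of EpiloguePipeline::MemoryOperation throughout this patch.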
const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], bs_block_window[UniversalGemmKernel::I0], num_loop, @@ -253,7 +248,9 @@ struct StreamKKernel if(UseDefaultScheduler || (get_warp_id() == 0)) { // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = + UniversalGemmKernel::template MakeCBlockWindows( + c_ptr, kargs, block_idx_m, block_idx_n); EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); } @@ -525,21 +522,13 @@ struct StreamKKernel const BDataType* b_ptr = static_cast(kargs.bs_ptr[0]) + i_k_b; CDataType* c_ptr = static_cast(kargs.e_ptr); - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - UniversalGemmKernel::template MakeGemmTensorViews< - EpiloguePipeline::MemoryOperation>( - {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, k_size); - - const auto& gemm_pad_views = - UniversalGemmKernel::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - UniversalGemmKernel::MakeGemmTileWindows(gemm_pad_views, i_m, i_n); - - // Run GEMM cooperatively by whole workgroup. - const auto& as_block_window = gemm_tile_windows.at(UniversalGemmKernel::I0); - const auto& bs_block_window = gemm_tile_windows.at(UniversalGemmKernel::I1); - const auto& ds_block_window = gemm_tile_windows.at(UniversalGemmKernel::I2); + // Create block windows using specialized methods + const auto& as_block_window = + UniversalGemmKernel::MakeABlockWindows({a_ptr}, kargs, k_size, i_m); + const auto& bs_block_window = + UniversalGemmKernel::MakeBBlockWindows({b_ptr}, kargs, k_size, i_n); + const auto& ds_block_window = + UniversalGemmKernel::MakeDBlockWindows({/*ds_ptr*/}, kargs, i_m, i_n); // Since num_loop can vary per WG and per iteration of the Stream-K while loop, // we compute has_hot_loop and tail_num here. This is a similar pattern used by @@ -548,6 +537,7 @@ struct StreamKKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop_sk); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop_sk); + // Run GEMM cooperatively by whole workgroup. const auto& c_block_tile = GemmPipeline{}(as_block_window[UniversalGemmKernel::I0], bs_block_window[UniversalGemmKernel::I0], num_loop_sk, @@ -594,7 +584,8 @@ struct StreamKKernel } } - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows< + TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n); EpiloguePipeline{}( c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); } @@ -617,7 +608,8 @@ struct StreamKKernel // tensor. 
if(tile_started && !partner_in_tile) { - auto& c_block_window = gemm_tile_windows.at(UniversalGemmKernel::I3); + auto c_block_window = UniversalGemmKernel::template MakeCBlockWindows< + TilePartitioner::MemoryOperation>(c_ptr, kargs, i_m, i_n); EpiloguePipeline{}( c_block_window, accum_block_tile, ds_block_window, smem_ptr_0); break; diff --git a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp index a6022e8b8e..0b0f6c18ef 100644 --- a/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp +++ b/include/ck_tile/ops/gemm/kernel/streamk_gemm/streamk_gemm_tile_partitioner.hpp @@ -27,6 +27,9 @@ struct StreamKTilePartitionerBase static constexpr index_t NPerBlock = BlockGemmShapeType::kN; static constexpr index_t KPerBlock = BlockGemmShapeType::kK; static constexpr StreamKReductionStrategy ReductionStrategy = ReductionStrategyType; + static constexpr auto MemoryOperation = (ReductionStrategy == StreamKReductionStrategy::Atomic) + ? memory_operation_enum::atomic_add + : memory_operation_enum::set; StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid); diff --git a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp index 77952c9afd..65f58a8ca5 100644 --- a/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp +++ b/include/ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp @@ -254,6 +254,8 @@ struct UniversalGemmKernel static_assert(DsLayout::size() == DsDataType::size(), "The size of DsLayout and DsDataType should be the same"); + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + using KernelArgs = UniversalGemmKernelArgs; @@ -609,17 +611,13 @@ struct UniversalGemmKernel return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid; } - template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const std::array& as_ptr, - const std::array& bs_ptr, - const std::array& ds_ptr, - EDataType* e_ptr, - const KernelArgs& kargs, - const index_t k_size) + MakeABlockWindows(const std::array& as_ptr, + const KernelArgs& kargs, + const index_t k_size, + const index_t i_m) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - + // Step 1: Create tensor views for A tensors (from MakeGemmTensorViews) const auto& as_tensor_view = generate_tuple( [&](auto i) { using AiLayout = remove_cvref_t>; @@ -645,6 +643,58 @@ struct UniversalGemmKernel }, number{}); + // Step 2: Create padded views (from MakeGemmPadViews) + const auto& as_pad_view = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(as_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(as_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows (from MakeGemmTileWindows) + const auto& as_block_window = generate_tuple( + [&](auto i) { + using AiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(as_pad_view[i], + make_tuple(number{}, + number{}), + {0, i_m}); + } + }, + number{}); + + return as_block_window; + } + + CK_TILE_DEVICE static auto + MakeBBlockWindows(const std::array& bs_ptr, + const KernelArgs& kargs, + const index_t k_size, + const 
index_t i_n) + { + // Step 1: Create tensor views for B tensors (from MakeGemmTensorViews) const auto& bs_tensor_view = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -733,96 +783,20 @@ struct UniversalGemmKernel }, number{}); - const auto& ds_tensor_view = generate_tuple( - [&](auto i) { - using DiLayout = remove_cvref_t>; - using DDataType_ = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.M, kargs.N), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - static_cast(ds_ptr[i]), - make_tuple(kargs.N, kargs.M), - make_tuple(kargs.stride_Ds[i], 1), - number{}, - number<1>{}); - } - }, - number{}); - - // TODO: enable vector write for C in ColMajor - const auto& e_tensor_view = [&]() { - if constexpr(std::is_same_v) - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), // arguments not matching with flatmm. - make_tuple(kargs.stride_E, 1), - number{}, - number<1>{}); - } - else - { - return make_naive_tensor_view( - e_ptr, - make_tuple(kargs.M, kargs.N), - make_tuple(1, kargs.stride_E), - number<1>{}, - number<1>{}); - } - }(); - - return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& as_pad_view = generate_tuple( - [&](auto i) { - const auto& a_tensor_view = views.at(I0); - using AiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(a_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(a_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - const auto& b_flat_pad_view = views.at(I1); - + // Step 2: Create padded views (from MakeGemmPadViews) const auto& bs_pad_view = generate_tuple( [&](auto i) { - const auto& b_tensor_view = views.at(I1); - using BiLayout = remove_cvref_t>; + using BiLayout = remove_cvref_t>; if constexpr(std::is_same_v) { - return pad_tensor_view(b_tensor_view[i], + return pad_tensor_view(bs_tensor_view[i], make_tuple(number{}, number{}), sequence{}); } else { - return pad_tensor_view(b_tensor_view[i], + return pad_tensor_view(bs_tensor_view[i], make_tuple(number{}, number{}), sequence{}); @@ -830,86 +804,7 @@ struct UniversalGemmKernel }, number{}); - const auto& ds_pad_view = generate_tuple( - [&](auto i) { - const auto& d_tensor_view = views.at(I2); - using DiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(d_tensor_view[i], - make_tuple(number{}, - number{}), - sequence{}); - } - }, - number{}); - - // TODO vector write in for C in ColMajor - const auto& e_pad_view = [&]() { - const auto& e_tensor_view = views.at(I3); - if constexpr(std::is_same_v) - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - else - { - return pad_tensor_view(e_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - } - }(); - - if constexpr(GemmPipeline::Preshuffle) - { - // For flatmm, we need to use the flat B tensor view - return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view); - } - else - { - return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view); - } - } - - template - CK_TILE_DEVICE static auto - 
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& as_pad_view = views.at(I0); - const auto& bs_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& e_pad_view = views.at(I3); - - const auto& as_block_window = generate_tuple( - [&](auto i) { - using AiLayout = remove_cvref_t>; - if constexpr(std::is_same_v) - { - return make_tile_window(as_pad_view[i], - make_tuple(number{}, - number{}), - {i_m, 0}); - } - else - { - return make_tile_window(as_pad_view[i], - make_tuple(number{}, - number{}), - {0, i_m}); - } - }, - number{}); - + // Step 3: Create tile windows (from MakeGemmTileWindows) const auto& bs_block_window = generate_tuple( [&](auto i) { using BiLayout = remove_cvref_t>; @@ -942,7 +837,63 @@ struct UniversalGemmKernel }, number{}); - const auto ds_block_window = generate_tuple( + return bs_block_window; + } + + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor views for D tensors (from MakeGemmTensorViews) + const auto& ds_tensor_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + using DDataType_ = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + static_cast(ds_ptr[i]), + make_tuple(kargs.N, kargs.M), + make_tuple(kargs.stride_Ds[i], 1), + number{}, + number<1>{}); + } + }, + number{}); + + // Step 2: Create padded views (from MakeGemmPadViews) + const auto& ds_pad_view = generate_tuple( + [&](auto i) { + using DiLayout = remove_cvref_t>; + if constexpr(std::is_same_v) + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(ds_tensor_view[i], + make_tuple(number{}, + number{}), + sequence{}); + } + }, + number{}); + + // Step 3: Create tile windows (from MakeGemmTileWindows) + const auto& ds_block_window = generate_tuple( [&](auto i) { using DiLayout = remove_cvref_t>; if constexpr(std::is_same_v) @@ -962,12 +913,62 @@ struct UniversalGemmKernel }, number{}); + return ds_block_window; + } + + template + CK_TILE_DEVICE static auto MakeCBlockWindows(EDataType* e_ptr, + const KernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for E/C tensor (from MakeGemmTensorViews) + const auto& e_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_E, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + e_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_E), + number<1>{}, + number<1>{}); + } + }(); + + // Step 2: Create padded view (from MakeGemmPadViews) + const auto& e_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(e_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window (from MakeGemmTileWindows) auto e_block_window = make_tile_window( e_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window); + return e_block_window; } /** @@ 
-995,30 +996,32 @@ struct UniversalGemmKernel
                                      const index_t block_idx_m,
                                      const index_t block_idx_n)
    {
-        // Create Gemm tensor views, pad views and tile windows
-        const auto& gemm_tensor_views_tuple =
-            MakeGemmTensorViews(
-                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
-
-        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+        // Create block windows using specialized methods
+        const auto& as_block_window =
+            MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
+        const auto& bs_block_window =
+            MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
+        const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);

        const index_t num_loop =
            amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

        // Run GEMM cooperatively by whole workgroup.
-        const auto& as_block_window = gemm_tile_windows.at(I0);
-        const auto& bs_block_window = gemm_tile_windows.at(I1);
-        const auto& ds_block_window = gemm_tile_windows.at(I2);
-
        const auto& c_block_tile = GemmPipeline{}.template operator()(
            as_block_window, AElementWise{}, bs_block_window, BElementWise{}, num_loop, smem_ptr_0);

-        if(UseDefaultScheduler || (get_warp_id() == 0))
+        const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);
+        // Run Epilogue Pipeline
+        if(k_batch == 1)
        {
-            // Run Epilogue Pipeline
-            auto& c_block_window = gemm_tile_windows.at(I3);
-
+            auto c_block_window = MakeCBlockWindows<memory_operation_enum::set>(
+                e_ptr, kargs, block_idx_m, block_idx_n);
+            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
+        }
+        else
+        {
+            auto c_block_window = MakeCBlockWindows<memory_operation_enum::atomic_add>(
+                e_ptr, kargs, block_idx_m, block_idx_n);
            EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
        }
    }
@@ -1051,22 +1054,17 @@ struct UniversalGemmKernel
                                      const index_t block_idx_m,
                                      const index_t block_idx_n)
    {
-        // Create Gemm tensor views, pad views and tile windows
-        const auto& gemm_tensor_views_tuple =
-            MakeGemmTensorViews(
-                as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset.splitted_k);
-
-        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
-        auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
+        // Create block windows using specialized methods
+        const auto& as_block_window =
+            MakeABlockWindows(as_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m);
+        const auto& bs_block_window =
+            MakeBBlockWindows(bs_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n);
+        const auto& ds_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n);

        const index_t num_loop =
            amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

        // Run GEMM cooperatively by whole workgroup.
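This k_batch branch moves the choice of destination memory operation from a kernel-wide template parameter to a runtime decision: a single K batch means the workgroup owns its output tile and can store directly, while multiple K batches contribute partial sums that must be accumulated atomically. A minimal self-contained sketch of the selection (memory_operation_enum matches the enum used elsewhere in this patch; run_epilogue is a hypothetical stand-in for opening the C window and running the epilogue):

    // Sketch of the runtime k_batch -> memory-operation selection used above.
    // memory_operation_enum mirrors ck_tile's enum; run_epilogue is a
    // hypothetical stand-in for opening the C window and running the epilogue.
    enum class memory_operation_enum { set, atomic_add };

    template <memory_operation_enum Op>
    void run_epilogue()
    {
        // Op == set:        this workgroup owns the tile, plain stores suffice.
        // Op == atomic_add: several K batches hit the same tile, so accumulate.
    }

    void dispatch_epilogue(int k_batch)
    {
        if(k_batch == 1)
            run_epilogue<memory_operation_enum::set>();
        else
            run_epilogue<memory_operation_enum::atomic_add>();
    }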
- const auto& as_block_window = gemm_tile_windows.at(I0); - const auto& bs_block_window = gemm_tile_windows.at(I1); - const auto& ds_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()(as_block_window, AElementWise{}, bs_block_window, @@ -1076,9 +1074,20 @@ struct UniversalGemmKernel smem_ptr_1); // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + if(kargs.k_batch == 1) + { + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindows( + e_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } // Non-persistent kernel entry point @@ -1119,39 +1128,30 @@ struct UniversalGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + RunGemm2LDS(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + smem_ptr_1, + kargs, + splitk_batch_offset, + i_m, + i_n); } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); - RunGemm(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - kargs, - splitk_batch_offset, - i_m, - i_n); - } + + constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1); + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } } @@ -1204,40 +1204,28 @@ struct UniversalGemmKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm2LDS(as_ptr, - bs_ptr, - kargs.ds_ptr, - e_ptr, - smem_ptr_0, - smem_ptr_1, - kargs, - splitk_batch_offset, - i_m, - i_n); - } - } - else - { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - EpiloguePipeline::GetVectorSizeC() % 2 != 0 && - is_any_of::value)) - { - RunGemm(as_ptr, + RunGemm2LDS(as_ptr, bs_ptr, kargs.ds_ptr, e_ptr, smem_ptr_0, + smem_ptr_1, kargs, splitk_batch_offset, i_m, i_n); - } + } + else + { + RunGemm(as_ptr, + bs_ptr, + kargs.ds_ptr, + e_ptr, + smem_ptr_0, + kargs, + splitk_batch_offset, + i_m, + i_n); } // Advance to the next work item block_id += grid_size; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index ba67a9ee4d..8aab756ccf 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -401,6 +401,592 @@ struct QuantGemmKernel index_t splitted_k; }; + CK_TILE_DEVICE static auto MakeABlockWindow(const ADataType* a_ptr, + const QuantGemmKernelArgs& kargs, + const index_t k_size, + 
const index_t i_m) + { + // Step 1: Create tensor view for A + const auto& a_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + a_ptr, + make_tuple(kargs.M, k_size), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + a_ptr, + make_tuple(k_size, kargs.M), + make_tuple(kargs.stride_A, 1), + number{}, + number<1>{}); + } + }(); + + // Step 2: Create padded view + const auto& a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(a_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + const auto& a_block_window = [&]() { + if constexpr(std::is_same_v) + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {i_m, 0}); + } + else + { + return make_tile_window(a_pad_view, + make_tuple(number{}, + number{}), + {0, i_m}); + } + }(); + + return a_block_window; + } + + CK_TILE_DEVICE static auto MakeAQBlockWindow(const AQDataType* aq_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for AQ + const auto& aq_tensor_view = [&]() { + if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) + { + static_assert(std::is_same_v); + const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ; + const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ; + const auto aq_desc = + make_naive_tensor_descriptor(make_tuple(aq_y, aq_x), + make_tuple(aq_x, 1), + number{}, + number<1>{}); + + const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ; + const auto aq_pad0_desc = transform_tensor_descriptor( + aq_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_right_pad_transform(aq_x, get_padding_size(aq_x, block_tile_size))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + const auto pad_aq_x = aq_pad0_desc.get_lengths()[I1]; + const auto wave_tile_size = + GemmPipeline::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ; + const auto wave_tile_count_x = + ck_tile::integer_divide_ceil(pad_aq_x, wave_tile_size); + + const auto aq_unmerge_pad0_desc = transform_tensor_descriptor( + aq_pad0_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_unmerge_transform(make_tuple(wave_tile_count_x, wave_tile_size))), + make_tuple(sequence<0>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1, 2>{})); + + const auto aq_pad1_desc = transform_tensor_descriptor( + aq_unmerge_pad0_desc, + make_tuple( + make_pass_through_transform(aq_y), + make_pass_through_transform(wave_tile_count_x), + make_right_pad_transform( + wave_tile_size, get_padding_size(wave_tile_size, get_warp_size()))), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{})); + + const auto pad_wave_size = + ck_tile::integer_least_multiple(wave_tile_size, get_warp_size()); + const auto aq_merge_pad1_desc = transform_tensor_descriptor( + aq_pad1_desc, + make_tuple(make_merge_transform(make_tuple(aq_y, wave_tile_count_x)), + make_pass_through_transform(pad_wave_size)), + make_tuple(sequence<0, 1>{}, sequence<2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return make_tensor_view(aq_ptr, aq_merge_pad1_desc); + } + else if constexpr((kQuantType == QuantType::AQuantGrouped || + kQuantType == 
QuantType::ABQuantGrouped) && + !PreshuffleQuant) + { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.QK_A), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + } + else // Column major AQ + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.QK_A, kargs.M), + make_tuple(kargs.stride_AQ, 1), + number{}, + number<1>{}); + } + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_naive_tensor_view( + aq_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, 0), // broadcasting over n + number<1>{}, + number<1>{}); + } + else + { + return nullptr; + } + }(); + + // Step 2: Create tile window (no padding for AQ) + const auto& aq_block_window = [&]() { + if constexpr(kQuantType == QuantType::AQuantGrouped && PreshuffleQuant) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + constexpr auto block_m = TilePartitioner::MPerBlock; + constexpr auto warp_m = GemmPipeline::BlockGemmShape::WarpTile::at(I0); + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto tile_window_width = + ck_tile::integer_least_multiple(warp_m * aqk_per_block, get_warp_size()); + constexpr auto tile_window_height = block_m / warp_m; + auto block_m_idx = i_m / block_m; + return make_tile_window( + aq_tensor_view, + make_tuple(number{}, number{}), + {block_m_idx * tile_window_height, 0}); + } + else if constexpr(kQuantType == QuantType::AQuantGrouped && !PreshuffleQuant) + { + using QuantGroupSize = remove_cvref_t; + constexpr auto aqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto block_m = TilePartitioner::MPerBlock; + if constexpr(std::is_same_v) + { + return make_tile_window(aq_tensor_view, + make_tuple(number{}, number{}), + {i_m, 0}); + } + else // Column major AQ + { + return make_tile_window(aq_tensor_view, + make_tuple(number{}, number{}), + {0, i_m}); + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped && !PreshuffleQuant) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + constexpr auto block_m = TilePartitioner::MPerBlock; + constexpr auto block_k = TilePartitioner::KPerBlock; + return make_tile_window( + aq_tensor_view, + make_tuple(number{}, number{}), + {i_m, 0}); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_tile_window(aq_tensor_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else + { + return nullptr; + } + }(); + + return aq_block_window; + } + + CK_TILE_DEVICE static auto MakeBBlockWindow(const BDataType* b_ptr, + const QuantGemmKernelArgs& kargs, + const index_t k_size, + const index_t i_n) + { + // Step 1: Create tensor view for B + const auto& b_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + return 
make_naive_tensor_view( + b_ptr, + make_tuple(k_size, kargs.N), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + else + { + if constexpr(GemmPipeline::BlockGemmShape::PermuteB) + { + constexpr index_t K1 = GemmPipeline::GetSmemPackB(); + const index_t K0 = k_size / K1; + constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB()); + const auto b_k0_n_k1_desc = + make_naive_tensor_descriptor(make_tuple(K0, kargs.N, K1), + make_tuple(kargs.N * K1, K1, I1), + number{}, + number<1>{}); + const auto b_n_k_desc = transform_tensor_descriptor( + b_k0_n_k1_desc, + make_tuple(make_merge_transform(make_tuple(K0, K1)), + make_pass_through_transform(kargs.N)), + make_tuple(sequence<0, 2>{}, sequence<1>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + return make_tensor_view(b_ptr, b_n_k_desc); + } + else + { + if constexpr(PreshuffleB) + { + index_t kFlatK = + GemmPipeline::flatKPerWarp * + (k_size / GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{})); + index_t kFlatN = kargs.N * kargs.K / kFlatK; + return make_naive_tensor_view( + b_ptr, + make_tuple(kFlatN, kFlatK), + make_tuple(kFlatK, 1), + number{}, + number<1>{}); + } + else + { + if constexpr(std::is_same_v) + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size / 2), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + else + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); + } + } + } + }(); + + // Step 2: Create padded view (or flat view for PreshuffleB) + const auto& b_pad_view = [&]() { + if constexpr(PreshuffleB) + { + return b_tensor_view; // no padding for preshuffle + } + else if constexpr(std::is_same_v) + { + if constexpr(std::is_same_v) + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + else + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + else + { + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + const auto& b_block_window = [&]() { + if constexpr(PreshuffleB) + { + return make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {static_cast(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)), 0}); + } + else + { + if constexpr(std::is_same_v) + { + if constexpr(std::is_same_v) + return make_tile_window( + b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + else + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); + } + else + { + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {0, i_n}); + } + } + }(); + + return b_block_window; + } + + CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for BQ + const auto& bq_tensor_view = [&]() { + if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_naive_tensor_view( + bq_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(0, 1), // broadcasting over m + number<1>{}, + number<1>{}); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + if constexpr(PreshuffleQuant) + { + static_assert(std::is_same_v, + "PreshuffleQuant with BQuantGrouped currently only supports " + "ColumnMajor BQ layout"); + + return MakePreshuffledQuantTensorView< + GemmPipeline::KPerBlockBQ, + GemmPipeline::NPerBlock, + 
TilePartitioner::BlockGemmShape::WarpTile::at(I1), + GemmPipeline::GetVectorSizeBQ()>(bq_ptr, kargs.N, kargs.QK_B); + } + else + { + using QuantGroupSize = remove_cvref_t; + + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), + integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + number{}, + number<1>{}); + } + else + { + static_assert(std::is_same_v); + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + integer_divide_ceil(kargs.K, QuantGroupSize::kK)), + make_tuple(integer_divide_ceil(kargs.K, QuantGroupSize::kK), 1), + number{}, + number<1>{}); + } + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + return make_naive_tensor_view( + bq_ptr, + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), kargs.QK_B), + make_tuple(kargs.stride_BQ, 1), + number{}, + number<1>{}); + } + else + { + return nullptr; + } + }(); + + // Step 2: Create tile window (no padding for BQ) + const auto& bq_block_window = [&]() { + if constexpr(kQuantType == QuantType::RowColQuant) + { + return make_tile_window(bq_tensor_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + using QuantGroupSize = remove_cvref_t; + if constexpr(PreshuffleQuant) + { + static_assert(std::is_same_v); + constexpr auto block_n = TilePartitioner::NPerBlock / QuantGroupSize::kN; + constexpr auto warp_n = TilePartitioner::BlockGemmShape::WarpTile::at(I1); + constexpr auto bqk_per_block = TilePartitioner::KPerBlock / QuantGroupSize::kK; + constexpr auto tile_window_width = + ck_tile::integer_least_multiple(warp_n * bqk_per_block, get_warp_size()); + constexpr auto tile_window_height = block_n / warp_n; + auto block_n_idx = i_n / block_n; + + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, number{}), + {block_n_idx * tile_window_height, 0}); + } + else + { + if constexpr(std::is_same_v) + { + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {0, i_n / QuantGroupSize::kN}); + } + else + { + static_assert(std::is_same_v); + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } + } + } + else if constexpr(kQuantType == QuantType::ABQuantGrouped) + { + static_assert(std::is_same_v); + using QuantGroupSize = remove_cvref_t; + return make_tile_window( + bq_tensor_view, + make_tuple(number{}, + number{}), + {i_n / QuantGroupSize::kN, 0}); + } + else + { + return nullptr; + } + }(); + + return bq_block_window; + } + + template + CK_TILE_DEVICE static auto MakeCBlockWindow(CDataType* c_ptr, + const QuantGemmKernelArgs& kargs, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for C + const auto& c_tensor_view = [&]() { + if constexpr(std::is_same_v) + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(kargs.stride_C, 1), + number{}, + number<1>{}); + } + else + { + return make_naive_tensor_view( + c_ptr, + make_tuple(kargs.M, kargs.N), + make_tuple(1, kargs.stride_C), + number<1>{}, + number<1>{}); + } + }(); + + // Step 2: Create padded view + const auto& c_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + 
sequence{}); + } + else + { + return pad_tensor_view(c_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + } + }(); + + // Step 3: Create tile window + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return c_block_window; + } + CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { if(kargs.k_batch != 1) @@ -1143,9 +1729,7 @@ struct QuantGemmKernel * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * - * @tparam DstInMemOp Destination memory operation (default: set). */ - template CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, const AQDataType* aq_ptr, @@ -1157,25 +1741,22 @@ struct QuantGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. 
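The immediately-invoked lambda below chooses the GEMM pipeline inputs per quantization mode: the grouped modes feed scale windows into the main loop, while RowColQuant and TensorQuant defer scaling to the epilogue. A reduced model of that constexpr selection (QuantType mirrors the kernel's enum; the strings merely label which windows each mode consumes):

    // Reduced model of the if-constexpr chain below; purely illustrative.
    enum class QuantType { AQuantGrouped, BQuantGrouped, ABQuantGrouped, RowColQuant, TensorQuant };

    template <QuantType Q>
    constexpr const char* pipeline_inputs()
    {
        if constexpr(Q == QuantType::AQuantGrouped)
            return "A, B, AQ";     // per-group A scales enter the main loop
        else if constexpr(Q == QuantType::BQuantGrouped)
            return "A, B, BQ";     // per-group B scales enter the main loop
        else if constexpr(Q == QuantType::ABQuantGrouped)
            return "A, B, AQ, BQ"; // both scale windows enter the main loop
        else
            return "A, B";         // RowCol/Tensor: scales applied in the epilogue
    }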
-        const auto& a_block_window = gemm_tile_windows.at(I0);
-        const auto& b_block_window = gemm_tile_windows.at(I2);
-
        const auto& c_block_tile = [&]() {
            if constexpr(kQuantType == QuantType::AQuantGrouped)
            {
-                const auto& aq_block_window = gemm_tile_windows.at(I1);
-                index_t m = 0;
+                index_t m = 0;
                if constexpr(PreshuffleQuant)
                {
                    m = kargs.M;
@@ -1185,8 +1766,7 @@ struct QuantGemmKernel
            }
            else if constexpr(kQuantType == QuantType::BQuantGrouped)
            {
-                const auto& bq_block_window = gemm_tile_windows.at(I3);
-                index_t n = 0;
+                index_t n = 0;
                if constexpr(PreshuffleQuant)
                {
                    n = kargs.N;
@@ -1196,10 +1776,8 @@ struct QuantGemmKernel
            }
            else if constexpr(kQuantType == QuantType::ABQuantGrouped)
            {
-                const auto& aq_block_window = gemm_tile_windows.at(I1);
-                const auto& bq_block_window = gemm_tile_windows.at(I3);
-                index_t m = 0;
-                index_t n = 0;
+                index_t m = 0;
+                index_t n = 0;
                if constexpr(PreshuffleQuant)
                {
                    m = kargs.M;
@@ -1222,86 +1800,111 @@ struct QuantGemmKernel
            }
        }();

-        // Run Epilogue Pipeline
-        auto& c_block_window = gemm_tile_windows.at(I4);
+        const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);

-        if constexpr(kQuantType == QuantType::ABQuantGrouped ||
-                     kQuantType == QuantType::AQuantGrouped ||
-                     kQuantType == QuantType::BQuantGrouped)
+        // Run Epilogue Pipeline with k_batch dispatch
+        if(k_batch == 1)
        {
-            EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+            auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
+                c_ptr, kargs, block_idx_m, block_idx_n);
+
+            if constexpr(kQuantType == QuantType::ABQuantGrouped ||
+                         kQuantType == QuantType::AQuantGrouped ||
+                         kQuantType == QuantType::BQuantGrouped)
+            {
+                EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+            }
+            else if constexpr(kQuantType == QuantType::RowColQuant)
+            {
+                EpiloguePipeline{}(c_block_window,
+                                   c_block_tile,
+                                   c_block_window,
+                                   smem_ptr_0,
+                                   aq_block_window,
+                                   bq_block_window);
+            }
+            else if constexpr(kQuantType == QuantType::TensorQuant)
+            {
+                const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
+                const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
+                EpiloguePipeline{}(
+                    c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
+            }
        }
-        else if constexpr(kQuantType == QuantType::RowColQuant)
+        else
        {
-            const auto& aq_block_window = gemm_tile_windows.at(I1);
-            const auto& bq_block_window = gemm_tile_windows.at(I3);
-            EpiloguePipeline{}(c_block_window,
-                               c_block_tile,
-                               c_block_window,
-                               smem_ptr_0,
-                               aq_block_window,
-                               bq_block_window);
-        }
-        else if constexpr(kQuantType == QuantType::TensorQuant)
-        {
-            // TODO: why doesn't readfirstlane work here?
-            // const AccDataType aq_scale =
-            //     __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*aq_ptr));
-            // const AccDataType bq_scale =
-            //     __builtin_amdgcn_readfirstlane(type_convert<AccDataType>(*bq_ptr));
-            const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
-            const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
-            EpiloguePipeline{}(
-                c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
+            auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
+                c_ptr, kargs, block_idx_m, block_idx_n);
+
+            if constexpr(kQuantType == QuantType::ABQuantGrouped ||
+                         kQuantType == QuantType::AQuantGrouped ||
+                         kQuantType == QuantType::BQuantGrouped)
+            {
+                EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
+            }
+            else if constexpr(kQuantType == QuantType::RowColQuant)
+            {
+                EpiloguePipeline{}(c_block_window,
+                                   c_block_tile,
+                                   c_block_window,
+                                   smem_ptr_0,
+                                   aq_block_window,
+                                   bq_block_window);
+            }
+            else if constexpr(kQuantType == QuantType::TensorQuant)
+            {
+                const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
+                const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
+                EpiloguePipeline{}(
+                    c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
+            }
        }
    }

    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     *
+     * @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism.
+     *
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
     * @param aq_ptr input AQ pointer
+     * @param bq_ptr input BQ pointer
     * @param c_ptr output C pointer
-     * @param smem_ptr_0 The start memory pointer of the shared memory block.
+     * @param smem_ptr_0 The starting pointer of the 1st shared memory block.
+     * @param smem_ptr_1 The starting pointer of the 2nd shared memory block.
     * @param kargs GEMM kernel arguments
-     * @param splitk_batch_offset splitk_batch_offset Utility structure used to calculate k batch.
+     * @param splitk_batch_offset Utility structure used to calculate k batch.
     * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
     * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
     *
-     * @tparam DstInMemOp Destination memory operation (default: set).
*/ - template CK_TILE_DEVICE static void RunGemm2LDS(const ADataType* a_ptr, const BDataType* b_ptr, - const AQDataType* aq_ptr, + [[maybe_unused]] const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, - void* smem_ptr_0, - void* smem_ptr_1, + void* __restrict__ smem_ptr_0, + void* __restrict__ smem_ptr_1, const QuantGemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - - const index_t num_loop = __builtin_amdgcn_readfirstlane( - TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); + const index_t num_loop = + amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = [&]() { if constexpr(kQuantType == QuantType::BQuantGrouped) { - const auto& bq_block_window = gemm_tile_windows.at(I3); - index_t n = 0; + index_t n = 0; if constexpr(PreshuffleQuant) { n = kargs.N; @@ -1320,19 +1923,23 @@ struct QuantGemmKernel } }(); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I4); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); + // Run Epilogue Pipeline with k_batch dispatch if constexpr(kQuantType == QuantType::BQuantGrouped) { - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); - } - else - { - return; - // throw std::runtime_error("DoubleSmemBuffer Not implemented for AQuantGrouped or - // RowColQuant"); static_assert(kQuantType == QuantType::BQuantGrouped, - // "DoubleSmemBuffer Not implemented"); + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } } } @@ -1343,16 +1950,19 @@ struct QuantGemmKernel const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock); const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock); const SplitKBatchOffset splitk_batch_offset(kargs); - // options - const ADataType* a_ptr = static_cast(kargs.a_ptr); - const BDataType* b_ptr = static_cast(kargs.b_ptr); + + // Apply splitk offset to input pointers + const ADataType* a_ptr = + static_cast(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset; + const BDataType* b_ptr = + static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); const BQDataType* bq_ptr = static_cast(kargs.bq_ptr); CDataType* c_ptr = static_cast(kargs.c_ptr); // 
allocate LDS __shared__ char smem_ptr_0[GetSmemSize()]; - assert(kargs.k_batch == 1); + if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 7e246961cb..1c98a372be 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -374,7 +374,7 @@ struct QuantGroupedGemmKernel CK_TILE_DEVICE static void RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr, const BDataType* b_ptr, - const AQDataType* aq_ptr, + [[maybe_unused]] const AQDataType* aq_ptr, const BQDataType* bq_ptr, CDataType* c_ptr, void* smem_ptr_0, @@ -385,25 +385,21 @@ struct QuantGroupedGemmKernel const index_t block_idx_n) { static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped"); - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& bq_block_window = + Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = __builtin_amdgcn_readfirstlane( TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - // Run GEMM cooperatively by whole workgroup. 
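Split-K support in these quant kernels rests on two changes visible above: the assert on kargs.k_batch == 1 is gone, and operator() now offsets the A and B pointers per K batch so each batch reads a disjoint slice. A rough model of that offsetting, assuming element-unit offsets and illustrative member names rather than the kernels' actual ones:

    // Rough model: each K batch advances A and B by precomputed element
    // offsets, then runs an ordinary GEMM over its splitted_k slice.
    struct SplitKBatchOffsetModel
    {
        long a_k_split_offset; // elements to skip in A for this batch
        long b_k_split_offset; // elements to skip in B for this batch
        long splitted_k;       // K extent this batch actually processes
    };

    template <typename T>
    const T* slice_ptr(const T* base, long element_offset)
    {
        return base + element_offset; // offsets are in elements of T, not bytes
    }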
- const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I2); - - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, + // Run GEMM cooperatively by whole workgroup + const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, b_block_window, bq_block_window, num_loop, @@ -411,10 +407,20 @@ struct QuantGroupedGemmKernel smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + // Run Epilogue Pipeline with split_k dispatch + if(kargs.k_batch == 1) + { + auto c_block_window = Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else + { + auto c_block_window = + Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } } /** @@ -449,16 +455,15 @@ struct QuantGroupedGemmKernel const index_t block_idx_m, const index_t block_idx_n) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - Base::template MakeGemmTensorViews( - a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset); - - const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); - const auto& a_block_window = gemm_tile_windows.at(Base::I0); - const auto& b_block_window = gemm_tile_windows.at(Base::I2); + // Create block windows using specialized methods + const auto& a_block_window = + Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); + const auto& b_block_window = + Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); + const auto& aq_block_window = + Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = + Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( @@ -466,51 +471,77 @@ struct QuantGroupedGemmKernel const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - if constexpr(kQuantType == QuantType::AQuantGrouped) + // Run GEMM cooperatively by whole workgroup + const auto& c_block_tile = [&]() { + if constexpr(kQuantType == QuantType::AQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + aq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::BQuantGrouped) + { + return GemmPipeline{}.template operator()(a_block_window, + b_block_window, + bq_block_window, + num_loop, + has_hot_loop, + tail_num, + smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::RowColQuant || + kQuantType == QuantType::TensorQuant) + { + return GemmPipeline{}.template operator()( + a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); + } + }(); + + // Run Epilogue Pipeline with split_k dispatch + if(kargs.k_batch == 1) { - const auto& aq_block_window = gemm_tile_windows.at(Base::I1); - // Run GEMM pipeline - const auto& c_block_tile = 
GemmPipeline{}.template operator()(a_block_window, - b_block_window, - aq_block_window, - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); + auto c_block_window = Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - // Run Epilogue Pipeline - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); - } - else if constexpr(kQuantType == QuantType::BQuantGrouped) - { - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, - b_block_window, - bq_block_window, - num_loop, - has_hot_loop, - tail_num, - smem_ptr_0); - - auto& c_block_window = gemm_tile_windows.at(Base::I4); - - // Run Epilogue Pipeline - EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + if constexpr(kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::RowColQuant) + { + EpiloguePipeline{}(c_block_window, + c_block_tile, + c_block_window, + smem_ptr_0, + aq_block_window, + bq_block_window); + } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + const AccDataType aq_scale = type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + } } else { - // Run GEMM pipeline - const auto& c_block_tile = GemmPipeline{}.template operator()( - a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(Base::I4); - if constexpr(kQuantType == QuantType::RowColQuant) + auto c_block_window = + Base::template MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + if constexpr(kQuantType == QuantType::AQuantGrouped || + kQuantType == QuantType::BQuantGrouped) + { + EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0); + } + else if constexpr(kQuantType == QuantType::RowColQuant) { - const auto& aq_block_window = gemm_tile_windows.at(Base::I1); - const auto& bq_block_window = gemm_tile_windows.at(Base::I3); EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp index ad445e17a7..2e5f536ab7 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp @@ -617,6 +617,117 @@ struct GroupedConvolutionBackwardDataKernel return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize()); } + CK_TILE_DEVICE static auto + MakeABlockWindow(const OutDataType* a_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_m, + const index_t i_k) + { + // Step 1: Create tensor view for A (Out tensor) + const auto& a_tensor_view = + make_tensor_view(a_ptr, kargs.a_grid_descs_m_k[group_id]); + + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto a_block_window = make_tile_window( 
+ a_pad_view, + make_tuple(number{}, number{}), + {i_m, i_k}); + + return a_block_window; + } + + CK_TILE_DEVICE static auto + MakeBBlockWindow(const InDataType* b_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_n, + const index_t i_k) + { + // Step 1: Create tensor view for B (Weight tensor) + const auto& b_tensor_view = + make_tensor_view(b_ptr, kargs.b_grid_descs_n_k[group_id]); + + // Step 2: Create padded view + const auto& b_pad_view = pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto b_block_window = make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {i_k, i_n}); + + return b_block_window; + } + + CK_TILE_DEVICE static auto + MakeDBlockWindows(const std::array& ds_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_m, + const index_t i_n) + { + // Create D tensor block windows + const auto ds_block_window = generate_tuple( + [&](auto i) { + // Step 1: Create tensor view for D + const auto& d_tensor_view = make_tensor_view( + static_cast(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]); + + // Step 2: Create padded view + const auto& d_pad_view = + pad_tensor_view(d_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window(d_pad_view, + make_tuple(number{}, + number{}), + {i_m, i_n}); + }, + number{}); + + return ds_block_window; + } + + template + CK_TILE_DEVICE static auto + MakeCBlockWindow(WeiDataType* c_ptr, + const GroupedConvBwdDataKernelArgsSpecialized& kargs, + const index_t group_id, + const index_t i_m, + const index_t i_n) + { + // Step 1: Create tensor view for C (Input tensor) + const auto& c_tensor_view = make_tensor_view( + c_ptr, kargs.c_grid_descs_m_n[group_id]); + + // Step 2: Create padded view + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + auto c_block_window = make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {i_m, i_n}); + + return c_block_window; + } + CK_TILE_HOST static bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs) { @@ -895,38 +1006,49 @@ struct GroupedConvolutionBackwardDataKernel const index_t block_idx_k, const index_t group_id) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); + // Create block windows using specialized methods + const auto& a_block_window = + MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k); + const auto& b_block_window = + MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k); + const auto& d_block_window = + MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k)); const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop); const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); - // Run GEMM cooperatively by whole workgroup. 
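Each Make*BlockWindow helper added in this file follows the three-step recipe called out in its comments: wrap the raw pointer in a tensor view, pad the view to a tile multiple so edge tiles stay in bounds, then open a tile window at this block's coordinates. A toy model of the recipe, with stand-in types rather than ck_tile's:

    // Toy model of the view -> pad -> window recipe used by each helper above.
    // The real ck_tile types are far richer; these structs are illustrative only.
    struct View
    {
        const float* data;
        int rows;
        int cols;
    };

    struct Window
    {
        View view; // padded view the window reads through
        int row0;  // block origin, row
        int col0;  // block origin, column
    };

    // Step 2: round each extent up to a tile multiple so edge tiles stay in bounds.
    View pad_view(View v, int tile_m, int tile_n)
    {
        v.rows = (v.rows + tile_m - 1) / tile_m * tile_m;
        v.cols = (v.cols + tile_n - 1) / tile_n * tile_n;
        return v;
    }

    // Steps 1-3 composed, mirroring MakeABlockWindow and friends.
    Window make_block_window(const float* p, int rows, int cols,
                             int tile_m, int tile_n, int i_m, int i_n)
    {
        View raw{p, rows, cols};                     // Step 1: tensor view
        View padded = pad_view(raw, tile_m, tile_n); // Step 2: padded view
        return Window{padded, i_m, i_n};             // Step 3: tile window
    }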
-        const auto& a_block_window = gemm_tile_windows.at(I0);
-        const auto& b_block_window = gemm_tile_windows.at(I1);
-        const auto& d_block_window = gemm_tile_windows.at(I2);
-
        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);

-        // Run Epilogue Pipeline
-        auto& c_block_window = gemm_tile_windows.at(I3);
+        const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch);

-        EpiloguePipeline{}.template operator()(
-            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
+        // Run Epilogue Pipeline with k_batch dispatch
+        if(k_batch == 1)
+        {
+            auto c_block_window = MakeCBlockWindow<memory_operation_enum::set>(
+                c_ptr, kargs, group_id, block_idx_m, block_idx_n);
+
+            EpiloguePipeline{}
+                .template operator()(
+                    c_block_window, c_block_tile, d_block_window, smem_ptr_0);
+        }
+        else
+        {
+            auto c_block_window = MakeCBlockWindow<memory_operation_enum::atomic_add>(
+                c_ptr, kargs, group_id, block_idx_m, block_idx_n);
+
+            EpiloguePipeline{}
+                .template operator()(
+                    c_block_window, c_block_tile, d_block_window, smem_ptr_0);
+        }
    }

    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     *
-     * @note RunGEMM2LDS in with two shared memory buffers using the ping pong buffer mechanism.
+     * @note RunGemm2LDS runs with two shared memory buffers using the ping-pong buffer mechanism.
     *
     * @param a_ptr input A pointer
     * @param b_ptr input B pointer
@@ -951,23 +1073,19 @@ struct GroupedConvolutionBackwardDataKernel
                                    const index_t block_idx_k,
                                    const index_t group_id)
    {
-        // Create Gemm tensor views, pad views and tile windows
-        const auto& gemm_tensor_views_tuple =
-            MakeGemmTensorViews(
-                a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
-        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
+        // Create block windows using specialized methods
+        const auto& a_block_window =
+            MakeABlockWindow(a_ptr, kargs, group_id, block_idx_m, block_idx_k);
+        const auto& b_block_window =
+            MakeBBlockWindow(b_ptr, kargs, group_id, block_idx_n, block_idx_k);
+        const auto& d_block_window =
+            MakeDBlockWindows(ds_ptr, kargs, group_id, block_idx_m, block_idx_n);

        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitted_k));
        const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

-        auto gemm_tile_windows =
-            MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
-
        // Run GEMM cooperatively by whole workgroup.
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, @@ -976,11 +1094,27 @@ struct GroupedConvolutionBackwardDataKernel smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + const index_t k_batch = amd_wave_read_first_lane(kargs.k_batch); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + // Run Epilogue Pipeline with k_batch dispatch + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, group_id, block_idx_m, block_idx_n); + + EpiloguePipeline{} + .template operator()( + c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs, @@ -1066,8 +1200,7 @@ struct GroupedConvolutionBackwardDataKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1086,8 +1219,7 @@ struct GroupedConvolutionBackwardDataKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp index 4b7ad72ffc..6bcd05e9ba 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp @@ -518,25 +518,6 @@ struct GroupedConvolutionBackwardWeightKernel return false; } -#if defined(__gfx11__) - if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set) - { - return false; - } -#endif - - if constexpr(EpiloguePipeline_::MemoryOperation == memory_operation_enum::atomic_add) - { - if(kargs.k_batch == 1) - { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) - { - CK_TILE_ERROR("Atomic add epilogue only supports k_batch > 1."); - } - return false; - } - } - if constexpr(!std::is_same_v && !std::is_same_v) { @@ -704,29 +685,31 @@ struct GroupedConvolutionBackwardWeightKernel template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const OutDataType* a_ptr, - const InDataType* b_ptr, - const std::array& ds_ptr, - WeiDataType* c_ptr, - const GroupedConvBwdWeightKernelArgsSpecialized& kargs) + MakeCBlockWindow(WeiDataType* c_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); - const 
auto& a_tensor_view = [&]() { - return make_tensor_view(a_ptr, - kargs.a_grid_desc_k_m); // A: out - }(); + const auto& c_tensor_view = + make_tensor_view(c_ptr, kargs.c_grid_desc_m_n); - const auto& b_tensor_view = [&]() { - return make_tensor_view(b_ptr, - kargs.b_grid_desc_k_n); // B: in - }(); + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); - const auto& c_tensor_view = [&]() { - return make_tensor_view(c_ptr, - kargs.c_grid_desc_m_n); - }(); + return make_tile_window( + c_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, block_idx_n}); + } + CK_TILE_DEVICE static auto + MakeDBlockWindows(const std::array& ds_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_n) + { const auto& ds_tensor_view = generate_tuple( [&](auto i) { static_assert(std::is_same_v, OutLayout>, @@ -741,30 +724,7 @@ struct GroupedConvolutionBackwardWeightKernel }, number{}); - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& ds_tensor_view = views.at(I2); - const auto& ds_pad_view = generate_tuple( + const auto& ds_pad_view = generate_tuple( [&](auto i) { return pad_tensor_view(ds_tensor_view[i], make_tuple(number{}, @@ -773,67 +733,58 @@ struct GroupedConvolutionBackwardWeightKernel }, number{}); - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I3); - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); - } - - /** - * @brief Create views to the data that each workgroup will process. 
- * - * @param views padded views of A, B, D and C tensors - * @param i_m block m-index - * @param i_n block n-index - * @param i_k block k-index - * - * @return tuple of tile windows for A, B, D and C tensors - */ - template - CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views, - const index_t i_m, - const index_t i_n, - const index_t i_k) - { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& c_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_k, i_m}); - }(); - - const auto& b_block_window = [&]() { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_k, i_n}); - }(); - - const auto ds_block_window = generate_tuple( + return generate_tuple( [&](auto i) { return make_tile_window(ds_pad_view[i], make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); }, number{}); + } - auto c_block_window = make_tile_window( - c_pad_view, - make_tuple(number{}, number{}), - {i_m, i_n}); + CK_TILE_DEVICE static auto + MakeBBlockWindow(const InDataType* b_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_n, + const index_t block_idx_k) + { + static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); + const auto& b_tensor_view = + make_tensor_view(b_ptr, kargs.b_grid_desc_k_n); - return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + const auto& b_pad_view = + pad_tensor_view(b_tensor_view, + make_tuple(number{} * kargs.k_batch, + number{}), + sequence{}); + + return make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {block_idx_k, block_idx_n}); + } + + CK_TILE_DEVICE static auto + MakeABlockWindow(const OutDataType* a_ptr, + const GroupedConvBwdWeightKernelArgsSpecialized& kargs, + const index_t block_idx_m, + const index_t block_idx_k) + { + static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); + const auto& a_tensor_view = + make_tensor_view(a_ptr, kargs.a_grid_desc_k_m); + + const auto& a_pad_view = + pad_tensor_view(a_tensor_view, + make_tuple(number{} * kargs.k_batch, + number{}), + sequence{}); + + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_k, block_idx_m}); } /** @@ -859,28 +810,30 @@ struct GroupedConvolutionBackwardWeightKernel const index_t block_idx_n, const index_t block_idx_k) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + // Create block windows using helper methods + const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k); + const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k); + const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Run GEMM cooperatively by whole workgroup. 
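Note that MakeABlockWindow and MakeBBlockWindow above pad the K extent by the block K size times kargs.k_batch rather than the block K size alone, plausibly so every split-K batch sees the same padded K and the per-batch tile counts agree. A small self-contained check of that invariant with illustrative values:

    // Illustrative check: padding K up to a multiple of KPerBlock * k_batch
    // makes every split-K slice tile cleanly with the same loop count.
    #include <cassert>

    int padded_k(int K, int k_per_block, int k_batch)
    {
        const int unit = k_per_block * k_batch;
        return (K + unit - 1) / unit * unit; // round up to the pad unit
    }

    int main()
    {
        const int K = 1000, KPerBlock = 64, k_batch = 4;
        const int pk = padded_k(K, KPerBlock, k_batch); // -> 1024
        assert(pk % (KPerBlock * k_batch) == 0);        // splits evenly across batches
        assert((pk / k_batch) % KPerBlock == 0);        // each slice tiles cleanly
        return 0;
    }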
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(kargs.k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } /** @@ -910,27 +863,33 @@ struct GroupedConvolutionBackwardWeightKernel const index_t block_idx_n, const index_t block_idx_k) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, kargs); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = - MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k); + // Create block windows using helper methods + const auto& a_block_window = MakeABlockWindow(a_ptr, kargs, block_idx_m, block_idx_k); + const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, block_idx_n, block_idx_k); + const auto& d_block_window = MakeDBlockWindows(ds_ptr, kargs, block_idx_m, block_idx_n); // Run GEMM cooperatively by whole workgroup. - const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(kargs.k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); - EpiloguePipeline{}.template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } + else + { +#if defined(__gfx11__) + return; +#endif + auto c_block_window = MakeCBlockWindow( + c_ptr, kargs, block_idx_m, block_idx_n); + + EpiloguePipeline{}(c_block_window, c_block_tile, d_block_window, smem_ptr_0); + } } CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const @@ -960,12 +919,6 @@ struct GroupedConvolutionBackwardWeightKernel CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized& kargs) const { -#if defined(__gfx11__) - if constexpr(EpiloguePipeline::MemoryOperation != ck_tile::memory_operation_enum::set) - { - return; - } -#endif if constexpr(GroupedConvTraitsType_::ExplicitGemm) { CallExplicitGemm(kargs); @@ -1001,9 +954,7 @@ struct GroupedConvolutionBackwardWeightKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && 
is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1021,9 +972,7 @@ struct GroupedConvolutionBackwardWeightKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, diff --git a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp index 0f143d7ff7..1b81bce34a 100644 --- a/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp +++ b/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp @@ -794,34 +794,53 @@ struct GroupedConvolutionForwardKernel return true; } - template + template CK_TILE_DEVICE static auto - MakeGemmTensorViews(const InDataType* a_ptr, - const WeiDataType* b_ptr, - const std::array& ds_ptr, - OutDataType* c_ptr, - const ADescType& a_desc, - const BDescType& b_desc, - const CDescType& c_desc) + MakeABlockWindow(const InDataType* a_ptr, const ADescType& a_desc, const index_t block_idx_m) { - static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!"); - static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!"); - const auto& a_tensor_view = [&]() { - return make_tensor_view(a_ptr, a_desc); - }(); + // Step 1: Create tensor view + const auto& a_tensor_view = make_tensor_view(a_ptr, a_desc); - const auto& b_tensor_view = [&]() { - return make_tensor_view(b_ptr, b_desc); - }(); + // Step 2: Create padded view + const auto& a_pad_view = pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); - // TODO: enable vector write for C in ColMajor - const auto& c_tensor_view = [&]() { - return make_tensor_view(c_ptr, c_desc); - }(); + // Step 3: Create tile window + return make_tile_window( + a_pad_view, + make_tuple(number{}, number{}), + {block_idx_m, 0}); + } + template + CK_TILE_DEVICE static auto + MakeBBlockWindow(const WeiDataType* b_ptr, const BDescType& b_desc, const index_t block_idx_n) + { + // Step 1: Create tensor view + const auto& b_tensor_view = make_tensor_view(b_ptr, b_desc); + + // Step 2: Create padded view + const auto& b_pad_view = pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( + b_pad_view, + make_tuple(number{}, number{}), + {block_idx_n, 0}); + } + + template + CK_TILE_DEVICE static auto MakeDBlockWindows(const std::array& ds_ptr, + const CDescType& c_desc, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor views const auto& ds_tensor_view = generate_tuple( [&](auto i) { static_assert(std::is_same_v, OutLayout>, @@ -836,30 +855,8 @@ struct GroupedConvolutionForwardKernel }, number{}); - return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view); - } - - template - CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views) - { - const auto& a_pad_view = [&]() { - const auto& a_tensor_view = views.at(I0); - return pad_tensor_view(a_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& b_pad_view = [&]() { - const auto& b_tensor_view = views.at(I1); - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - const auto& ds_tensor_view = views.at(I2); - const auto& ds_pad_view = generate_tuple( + // 
Step 2: Create padded views + const auto& ds_pad_view = generate_tuple( [&](auto i) { return pad_tensor_view(ds_tensor_view[i], make_tuple(number{}, @@ -868,55 +865,38 @@ struct GroupedConvolutionForwardKernel }, number{}); - const auto& c_pad_view = [&]() { - const auto& c_tensor_view = views.at(I3); - return pad_tensor_view(c_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - }(); - - return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view); - } - - template - CK_TILE_DEVICE static auto - MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n) - { - const auto& a_pad_view = views.at(I0); - const auto& b_pad_view = views.at(I1); - const auto& ds_pad_view = views.at(I2); - const auto& c_pad_view = views.at(I3); - - const auto& a_block_window = [&]() { - return make_tile_window(a_pad_view, - make_tuple(number{}, - number{}), - {i_m, 0}); - }(); - - const auto& b_block_window = [&]() { - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - }(); - - const auto ds_block_window = generate_tuple( + // Step 3: Create tile windows + return generate_tuple( [&](auto i) { return make_tile_window(ds_pad_view[i], make_tuple(number{}, number{}), - {i_m, i_n}); + {block_idx_m, block_idx_n}); }, number{}); + } - auto c_block_window = make_tile_window( + template + CK_TILE_DEVICE static auto MakeCBlockWindow(OutDataType* c_ptr, + const CDescType& c_desc, + const index_t block_idx_m, + const index_t block_idx_n) + { + // Step 1: Create tensor view + const auto& c_tensor_view = + make_tensor_view(c_ptr, c_desc); + + // Step 2: Create padded view + const auto& c_pad_view = pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + + // Step 3: Create tile window + return make_tile_window( c_pad_view, make_tuple(number{}, number{}), - {i_m, i_n}); - - return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window); + {block_idx_m, block_idx_n}); } /** @@ -931,6 +911,7 @@ struct GroupedConvolutionForwardKernel * @param b_desc Weight tensor B descriptor * @param c_desc Output tensor C descriptor * @param gemm_k The GEMM K dimension + * @param k_batch The K batch parameter for split-K * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * @@ -945,34 +926,41 @@ struct GroupedConvolutionForwardKernel const BDescType& b_desc, const CDescType& c_desc, const index_t gemm_k, + const index_t k_batch, const index_t block_idx_m, const index_t block_idx_n, const CDElementwise& elfunc) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc); - - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m); + const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k)); // Run GEMM cooperatively by whole workgroup. 
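Editor's note: the `k_batch` parameter threaded through `RunGemm` above decides whether the epilogue may overwrite C or must accumulate into it. The arithmetic reason is that each k-batch owns one K-chunk and contributes only a partial sum; a minimal host-side sketch (plain C++, illustrative names):

```cpp
// Minimal split-K sketch: each k-batch computes a partial dot product over
// its K-chunk and adds it into a zero-initialized output. On the device the
// "+=" is an atomic_add; k_batch == 1 can simply overwrite (set).
#include <numeric>
#include <vector>

float split_k_dot(const std::vector<float>& a, const std::vector<float>& b, int k_batch)
{
    float c         = 0.f; // output must be cleared first
    const int chunk = static_cast<int>(a.size()) / k_batch;
    for(int kb = 0; kb < k_batch; ++kb) // each kb maps to a different workgroup
        c += std::inner_product(a.begin() + kb * chunk,
                                a.begin() + (kb + 1) * chunk,
                                b.begin() + kb * chunk,
                                0.f);
    return c;
}
```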
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); + + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } /** @@ -990,6 +978,7 @@ struct GroupedConvolutionForwardKernel * @param b_desc Weight tensor B descriptor * @param c_desc Output tensor C descriptor * @param gemm_k The GEMM K dimension + * @param k_batch The K batch parameter for split-K * @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup. * @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup. * @@ -1005,33 +994,41 @@ struct GroupedConvolutionForwardKernel const BDescType& b_desc, const CDescType& c_desc, const index_t gemm_k, + const index_t k_batch, const index_t block_idx_m, const index_t block_idx_n, const CDElementwise& elfunc) { - // Create Gemm tensor views, pad views and tile windows - const auto& gemm_tensor_views_tuple = - MakeGemmTensorViews( - a_ptr, b_ptr, ds_ptr, c_ptr, a_desc, b_desc, c_desc); - const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple); - auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n); + // Create block windows using specialized methods + const auto& a_block_window = MakeABlockWindow(a_ptr, a_desc, block_idx_m); + const auto& b_block_window = MakeBBlockWindow(b_ptr, b_desc, block_idx_n); + const auto& ds_block_window = MakeDBlockWindows(ds_ptr, c_desc, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(gemm_k)); // Run GEMM cooperatively by whole workgroup. 
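Editor's note: the `if(k_batch == 1)` branches above replace the old caller-side `Run(integral_constant<...>)` lambda; the dispatch now happens next to the epilogue instead of around the whole kernel instantiation. A generic restatement of the pattern, with `memory_op`/`store_tile` as illustrative stand-ins rather than the ck_tile types:

```cpp
// Generic shape of the k_batch dispatch above: a runtime value selects one of
// two compile-time instantiations.
enum class memory_op { set, atomic_add };

template <memory_op Op>
void store_tile(float* dst, float value)
{
    if constexpr(Op == memory_op::set)
        *dst = value;  // k_batch == 1: plain store
    else
        *dst += value; // k_batch > 1: atomicAdd on the device
}

void epilogue(float* dst, float value, int k_batch)
{
    if(k_batch == 1)
        store_tile<memory_op::set>(dst, value);
    else
        store_tile<memory_op::atomic_add>(dst, value);
}
```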
- const auto& a_block_window = gemm_tile_windows.at(I0); - const auto& b_block_window = gemm_tile_windows.at(I1); - const auto& d_block_window = gemm_tile_windows.at(I2); - const auto& c_block_tile = GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1); - // Run Epilogue Pipeline - auto& c_block_window = gemm_tile_windows.at(I3); + // Run Epilogue Pipeline with k_batch dispatching + if(k_batch == 1) + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); - EpiloguePipeline{elfunc} - .template operator()( - c_block_window, c_block_tile, d_block_window, smem_ptr_0); + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } + else + { + auto c_block_window = MakeCBlockWindow( + c_ptr, c_desc, block_idx_m, block_idx_n); + + EpiloguePipeline{elfunc} + .template operator()( + c_block_window, c_block_tile, ds_block_window, smem_ptr_0); + } } CK_TILE_DEVICE void CallExplicitGemm(GroupedConvFwdKernelArgsSpecialized& kargs) const @@ -1185,9 +1182,7 @@ struct GroupedConvolutionForwardKernel if constexpr(GemmPipeline::DoubleSmemBuffer == true) { __shared__ char smem_ptr_1[GemmPipeline::GetSmemSize()]; - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm2LDS(a_ptr, @@ -1200,6 +1195,7 @@ struct GroupedConvolutionForwardKernel b_desc, c_desc, kargs.GemmK, + kargs.k_batch, i_m, i_n, kargs.elfunc); @@ -1207,9 +1203,7 @@ struct GroupedConvolutionForwardKernel } else { - if constexpr(!(EpiloguePipeline::MemoryOperation == - memory_operation_enum::atomic_add && - GroupedConvTraitsType_::VectorSizeC % 2 != 0 && + if constexpr(!(GroupedConvTraitsType_::VectorSizeC % 2 != 0 && is_any_of::value)) { RunGemm(a_ptr, @@ -1221,6 +1215,7 @@ struct GroupedConvolutionForwardKernel b_desc, c_desc, kargs.GemmK, + kargs.k_batch, i_m, i_n, kargs.elfunc); diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp index 77eb416532..37005cccc1 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp +++ b/test/ck_tile/batched_gemm/test_batched_gemm_util.hpp @@ -99,62 +99,47 @@ class TestCkTileBatchedGemm : public ::testing::Test scheduler>; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::BatchedGemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::BatchedGemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch, args.batch_count); + const dim3 blocks = Kernel::BlockSize(); - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index 9b90110c07..0572115201 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -120,8 +120,8 @@ using SimpleCShuffleEpilogueProblem = MPerXdl, NPerXdl, KPerXdl, - false, // isCTransposed, - memory_operation_enum::set>; + false // isCTransposed + >; template auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index e949ed45e6..8dc2e88430 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -182,74 +182,58 @@ class TestCkTileGemmPipeline : public ::testing::Test using GemmPipeline = typename GemmPipelineTypeSelector::pipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + const auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + const dim3 blocks = Kernel::BlockSize(); + dim3 grids; + if constexpr(Persistent) { - Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - Run(ck_tile::integral_constant{}); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 7d82958acf..6fb1b77fa8 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -356,8 +356,7 @@ class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase>; + transpose_c>>; using Kernel = ck_tile::QuantGemmKernel>; + transpose_c>>; using Kernel = ck_tile::QuantGemmKernel>; + transpose_c>>; using Kernel = ck_tile::QuantGemmKernel; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = + std::conditional_t; - using GemmEpilogue = std:: - conditional_t; + using Kernel = ck_tile::GemmKernelMultiABD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiABD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp index 8217f5a3d9..6a6806641a 100644 --- a/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp +++ b/test/ck_tile/gemm_multi_d/test_gemm_multi_d_util.hpp @@ -170,88 +170,69 @@ class TestCkTileGemmMultiD : public ::testing::Test using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< + ck_tile::DefaultGemm2DEpilogueProblem>; - using DefaultGemmEpilogue = ck_tile::DefaultGemm2DEpilogue< - ck_tile::DefaultGemm2DEpilogueProblem>; + using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using CShuffleGemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using GemmEpilogue = + std::conditional_t; - using GemmEpilogue = std:: - conditional_t; + using Kernel = ck_tile::GemmKernelMultiD; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernelMultiD; - auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << GemmShape::GetName() << '\n' - << "pipeline: " << GemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z - << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + if(!Kernel::IsSupportedArgument(kargs)) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); } - else + + if(s.log_level_ > 0) { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp index 540109a999..237dc24c3b 100644 --- a/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp +++ b/test/ck_tile/gemm_streamk/test_gemm_streamk_util.hpp @@ -105,71 +105,60 @@ class TestCkTileStreamK : public ::testing::Test NumWaveGroup, preshuffle>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - // We create the GEMM pipeline without specifying has_hot_loop or tail_num. - // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K - // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K - // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - // For initial testing, we will just test with one pipeline. - // More extensive testing is coming later and will test other pipelines. - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + // We create the GEMM pipeline without specifying has_hot_loop or tail_num. + // This is because num_loop can vary (a) per WG and (b) per iteration of the Stream-K + // while loop. Instead, has_hot_loop and tail_num are determined in the Stream-K + // Kernel's RunGemm function. This is a similar pattern used by grouped GEMM. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // For initial testing, we will just test with one pipeline. + // More extensive testing is coming later and will test other pipelines. 
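Editor's note: since the `Run` wrappers no longer return the launch result, the tests above assign it to `ck_tile::ignore` to discard a `[[nodiscard]]` value without a warning. A sketch of that idiom, assuming `ck_tile::ignore` behaves like `std::ignore` (an assign-to-anything sink):

```cpp
// Sketch of the ignore idiom used above. ignore_t mimics what
// ck_tile::ignore is assumed to be; launch_kernel_stub stands in for
// ck_tile::launch_kernel.
struct ignore_t
{
    template <typename T>
    constexpr const ignore_t& operator=(T&&) const noexcept { return *this; }
};
inline constexpr ignore_t ignore{};

[[nodiscard]] float launch_kernel_stub() { return 0.f; }

void example() { ignore = launch_kernel_stub(); } // no -Wunused-result
```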
+ using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - UniversalGemmProblem::TransposeC, - memory_operation>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + UniversalGemmProblem::TransposeC>>; - using Kernel = ck_tile::StreamKKernel; + using Kernel = ck_tile::StreamKKernel; - auto kargs = Kernel::MakeKernelArgs(args); - const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); - ck_tile::DeviceMem workspace_data(workspace_size); - workspace_data.SetZero(); - kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); + auto kargs = Kernel::MakeKernelArgs(args); + const auto workspace_size = Kernel::GetWorkSpaceSize(kargs); + ck_tile::DeviceMem workspace_data(workspace_size); + workspace_data.SetZero(); + kargs.workspace_ptr = workspace_data.GetDeviceBuffer(); - if(!Kernel::IsSupportedArgument(kargs)) - { - EXPECT_TRUE(false); - } + if(!Kernel::IsSupportedArgument(kargs)) + { + EXPECT_TRUE(false); + } - dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); - dim3 block_dims = Kernel::BlockSize(); + dim3 grid_dims = Kernel::GridSize(kargs.tile_partitioner); + dim3 block_dims = Kernel::BlockSize(); - ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grid_dims, block_dims, 0, kargs)); - return kargs.tile_partitioner.estimate_num_wgs_per_tile(); - }; - - return Run(ck_tile::integral_constant{}); + return kargs.tile_partitioner.estimate_num_wgs_per_tile(); } public: diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 7c085b5098..875684ce08 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -180,68 +180,52 @@ class TestCkTileGemmPipeline : public ::testing::Test using GemmPipeline = typename GemmPipelineTypeSelector::pipeline; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); - using Kernel = ck_tile::GemmKernel; - auto kargs = Kernel::MakeKernelArgs(args); - - dim3 grids; - if constexpr(Persistent) - { - grids = Kernel::MaxOccupancyGridSize(s); - } - else - { - grids = Kernel::GridSize(args.M, args.N, args.k_batch); - } - const dim3 blocks = Kernel::BlockSize(); - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " - << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " - << blocks.y << ", " << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - }; - - if(args.k_batch == 1) + dim3 grids; + if constexpr(Persistent) { - Run(ck_tile::integral_constant{}); + grids = Kernel::MaxOccupancyGridSize(s); } else { - Run(ck_tile::integral_constant{}); + grids = Kernel::GridSize(args.M, args.N, args.k_batch); } + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y + << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y + << ", " << blocks.z << "}" << std::endl; + } + + ck_tile::ignore = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); } public: diff --git a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp index bdce90e385..237641a000 100644 --- a/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp +++ b/test/ck_tile/grouped_conv/test_ck_tile_grouped_conv_bwd_weight.cpp @@ -42,8 +42,7 @@ template + index_t NDimSpatial = 2> struct BuildKernel { using GemmShape = TileGemmShape< @@ -123,7 +122,6 @@ struct BuildKernel ConvConfig::N_Warp_Tile, ConvConfig::K_Warp_Tile, ConvTraits::FixedGemmParams::TransposeC, - MemOp, ConvConfig::NumWaveGroups, ConvTraits::FixedGemmParams::FixedVectorSize, ConvTraits::VectorSizeC>; @@ -212,26 +210,6 @@ TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, InvalidKBatchLessThanOne) EXPECT_FALSE(Kernel::IsSupportedArgument(kargs)); } -TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, AtomicAddRequiresKBatchGreaterThanOne) -{ - using Kernel = typename BuildKernel::type; - - // k_batch = 1 should fail with atomic_add - auto host_args_kbatch_1 = create_2d_host_args(1); - auto kargs_1 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_1); - EXPECT_FALSE(Kernel::IsSupportedArgument(kargs_1)); - - // k_batch = 2 should pass - auto host_args_kbatch_2 = create_2d_host_args(2); - auto kargs_2 = typename Kernel::GroupedConvBwdWeightKernelArgsSpecialized(host_args_kbatch_2); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs_2)); -} - TEST_F(GroupedConvBwdWeightIsSupportedArgumentTest, K0KBatchLimitation) { using Kernel = typename BuildKernel; using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - 
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + if(s.log_level_ > 0) { - std::cout << "Run without SplitK" << std::endl; - Run(ck_tile::integral_constant{}); - } - else - { - std::cout << "Run using SplitK" << std::endl; - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } template void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { constexpr bool TransposeC = false; constexpr bool DoubleSmemBuffer = false; @@ -212,50 +193,47 @@ class TestCkTileGroupedGemm : public ::testing::Test CLayout, TransposeC>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = memory_operation_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. - using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. 
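Editor's note: the persistent variant below sizes the grid with `MaxOccupancyGridSize` rather than from the problem shape; the usual intuition is that only as many workgroups launch as the GPU can keep resident, and each strides over all tiles. A plain C++ model of that loop (indices illustrative, not the ck_tile scheduling):

```cpp
// Host-side model of a persistent workgroup: grid_size workgroups cover
// num_tiles tiles by striding, instead of mapping one workgroup per tile.
#include <cstdio>

void persistent_workgroup(int wg_id, int grid_size, int num_tiles)
{
    for(int tile = wg_id; tile < num_tiles; tile += grid_size)
        std::printf("wg %d -> tile %d\n", wg_id, tile);
}
```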
+ using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + ck_tile::ignore = ck_tile::launch_kernel(s, ck_tile::make_kernel( Kernel{}, @@ -264,19 +242,6 @@ class TestCkTileGroupedGemm : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), num_groups)); - }; - - if(splitk) - { - Run(ck_tile::integral_constant{}); - } - else - { - - Run(ck_tile::integral_constant{}); - } } auto calculate_rtol_atol(const ck_tile::index_t K, @@ -422,8 +387,7 @@ class TestCkTileGroupedGemm : public ::testing::Test { // Generate kernel arguments std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = gemm_descs[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : gemm_descs) { kargs.emplace_back(ck_tile::UniversalGemmKernelArgs<>{{arg.a_ptr}, @@ -448,10 +412,10 @@ class TestCkTileGroupedGemm : public ::testing::Test stream.stream_id_)); #if CK_TILE_USE_WMMA invoke_grouped_gemm_persistent( - stream, group_count, kargs_ptr, splitk); + stream, group_count, kargs_ptr); #else invoke_grouped_gemm_persistent( - stream, group_count, kargs_ptr, splitk); + stream, group_count, kargs_ptr); #endif } else diff --git a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp index b065df6f8a..c6e311a65c 100644 --- a/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp +++ b/test/ck_tile/grouped_gemm_multi_d/test_grouped_gemm_multi_d_util.hpp @@ -96,7 +96,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test const ck_tile::stream_config& s, void* kargs_ptr) { - + EXPECT_TRUE(gemm_descs[0].k_batch == 1); using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, ck_tile::sequence, @@ -134,74 +134,56 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test ck_tile::GemmPipelineAgBgCrCompV3, ck_tile::GemmPipelineAgBgCrCompV4>>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - using GemmEpilogue = 
ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); - - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } void invoke_grouped_gemm_persistent(const ck_tile::stream_config& s, const ck_tile::index_t num_groups, - void* kargs_ptr, - bool splitk) + void* kargs_ptr) { using GemmShape = ck_tile::TileGemmShape< ck_tile::sequence, @@ -218,78 +200,58 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test BLayout, ELayout>; - float ave_time{0}; + // We create the GEMM pipeline without specifying hotloop or tailnumber. + // These are automatically run inside the kernel based on the given input data. + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; + using GemmPipeline = std::conditional_t< + Config::Pipeline_ == (PipelineType::Memory), + ck_tile::GemmPipelineAgBgCrMem, + std::conditional_t, + ck_tile::GemmPipelineAgBgCrCompV4>>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. 
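Editor's note: the grouped tests above all share the same kargs-workspace pattern: per-group argument structs are packed on the host, copied once with `hipMemcpyWithStream`, and the kernel indexes the device buffer by group id. A self-contained sketch, with `GroupArgs` as a hypothetical stand-in for the real `UniversalGemmKernelArgs`:

```cpp
// Upload per-group kernel arguments to a preallocated device workspace.
#include <hip/hip_runtime.h>
#include <vector>

struct GroupArgs { const void* a_ptr; const void* b_ptr; void* e_ptr; int M, N, K; };

hipError_t upload_group_args(void* dev_workspace,
                             const std::vector<GroupArgs>& kargs,
                             hipStream_t stream)
{
    return hipMemcpyWithStream(dev_workspace,
                               kargs.data(),
                               kargs.size() * sizeof(GroupArgs),
                               hipMemcpyHostToDevice,
                               stream);
}
```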
- using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; - - using GemmPipeline = std::conditional_t< - Config::Pipeline_ == (PipelineType::Memory), - ck_tile::GemmPipelineAgBgCrMem, - std::conditional_t, - ck_tile::GemmPipelineAgBgCrCompV4>>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } - - ave_time = ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - num_groups)); - - return ave_time; - }; - if(!splitk) + if(s.log_level_ > 0) { - Run(ck_tile::integral_constant{}); - } - else - { - Run(ck_tile::integral_constant{}); + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; } + + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + num_groups)); } public: @@ -445,8 +407,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test if constexpr(Config::Persistent_) { std::vector> kargs; - void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); - const bool splitk = gemm_descs[0].k_batch > 1; + void* kargs_ptr = gemm_workspace.GetDeviceBuffer(); for(const auto& arg : gemm_descs) { kargs.emplace_back( @@ -471,7 +432,7 @@ class TestCkTileGroupedGemmMultiD : public ::testing::Test hipMemcpyHostToDevice, stream.stream_id_)); - invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr, splitk); + invoke_grouped_gemm_persistent(stream, group_count, kargs_ptr); } else { diff --git a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp index a7189e7865..e588ad2cc1 100644 --- a/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp +++ b/test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp @@ -127,59 +127,44 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - 
hipMemcpyHostToDevice, - s.stream_id_)); + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; - - if(gemm_descs[0].k_batch == 1) - { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } private: @@ -226,59 +211,45 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test ck_tile::GemmPipelineScheduler::Default>; using GemmPipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; - const auto Run = [&](const auto memory_operation_) { - constexpr auto memory_operation = memory_operation_.value; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::GroupedGemmKernel; - auto kargs = Kernel::MakeKargs(gemm_descs); - EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); - const dim3 grids = Kernel::GridSize(gemm_descs); - const dim3 blocks = Kernel::BlockSize(); - ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, - kargs.data(), - get_workspace_size(gemm_descs), - hipMemcpyHostToDevice, - s.stream_id_)); + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GroupedGemmKernel; + auto kargs = Kernel::MakeKargs(gemm_descs); + EXPECT_TRUE(Kernel::IsSupportedArgument(kargs)); + const dim3 grids = Kernel::GridSize(gemm_descs); + const dim3 blocks = Kernel::BlockSize(); - return ck_tile::launch_kernel( - s, - ck_tile::make_kernel( - Kernel{}, - grids, - blocks, - 0, - ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), - gemm_descs.size())); - }; + ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr, + kargs.data(), + get_workspace_size(gemm_descs), + hipMemcpyHostToDevice, + s.stream_id_)); - if(gemm_descs[0].k_batch == 1) - { - Run(ck_tile::integral_constant{}); - } - else - { - // EXPECT TO FAIL because splitk is not supported - EXPECT_FALSE(true); - } + ck_tile::ignore = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, + grids, + blocks, + 0, + ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), + gemm_descs.size())); } struct BShuffleGemmConfig diff --git a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp index b73221ac28..3d52bca9e0 100644 --- a/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp +++ b/test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp @@ -148,10 +148,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test float ave_time{0}; const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = ck_tile::memory_operation_enum::set; + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; using QuantGemmProblem = 
std::conditional_t< UseGroupedQuant, @@ -217,8 +216,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test GroupedGemKernelParam::M_Warp_Tile, GroupedGemKernelParam::N_Warp_Tile, GroupedGemKernelParam::K_Warp_Tile, - QuantGemmProblem::TransposeC, - memory_operation>>; + QuantGemmProblem::TransposeC>>; using Kernel = ck_tile::QuantGroupedGemmKernel; - using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + using GemmUniversalTraits = ck_tile::TileGemmQuantTraits; + constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || + QuantType == ck_tile::QuantType::BQuantGrouped; + using QuantGemmProblem = std::conditional_t< + UseGroupedQuant, + std::conditional_t, + ck_tile::GemmBQuantPipelineProblem>, + ck_tile::GemmRowColTensorQuantPipelineProblem>; - const auto Run = [&](const auto memory_operation_) { - constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - constexpr auto memory_operation = memory_operation_.value; - // We create the GEMM pipeline without specifying hotloop or tailnumber. - // These are automatically run inside the kernel based on the given input data. + using GemmPipeline = std::conditional_t< + UseGroupedQuant, + std::conditional_t< + QuantType == ck_tile::QuantType::AQuantGrouped, + ck_tile::AQuantGemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, + ck_tile::GemmPipelineAgBgCrCompV3>; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::QuantGroupedGemmKernel; + const dim3 blocks = Kernel::BlockSize(); + const dim3 grids = Kernel::MaxOccupancyGridSize(s); - constexpr bool UseGroupedQuant = QuantType == ck_tile::QuantType::AQuantGrouped || - QuantType == ck_tile::QuantType::BQuantGrouped; - using QuantGemmProblem = std::conditional_t< - UseGroupedQuant, - std::conditional_t, - ck_tile::GemmBQuantPipelineProblem>, - ck_tile::GemmRowColTensorQuantPipelineProblem>; - - using GemmPipeline = std::conditional_t< - UseGroupedQuant, - std::conditional_t< - QuantType == ck_tile::QuantType::AQuantGrouped, - ck_tile::AQuantGemmPipelineAgBgCrCompV3, - std::conditional_t, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>, - ck_tile::GemmPipelineAgBgCrCompV3>; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem>; - using Kernel = ck_tile::QuantGroupedGemmKernel; - const dim3 blocks = Kernel::BlockSize(); - const dim3 grids = Kernel::MaxOccupancyGridSize(s); - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel: " << Kernel::GetName() - << " with args:" << " grid: {" << grids.x << ", " << grids.y << ", " - << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " - << blocks.z << "}" << std::endl; - } + if(s.log_level_ > 0) + { + std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {" + << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" + << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; + } + ck_tile::ignore = ck_tile::launch_kernel(s, ck_tile::make_kernel( Kernel{}, @@ -388,10 +379,6 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test 0, ck_tile::cast_pointer_to_constant_address_space(kargs_ptr), num_groups)); - }; - - Run(ck_tile::integral_constant{}); } template diff --git a/tile_engine/ops/gemm/gemm_instance_builder.py b/tile_engine/ops/gemm/gemm_instance_builder.py index 27ca805c2e..81a9b08b70 100644 --- 
a/tile_engine/ops/gemm/gemm_instance_builder.py +++ b/tile_engine/ops/gemm/gemm_instance_builder.py @@ -719,8 +719,8 @@ struct SelectedKernel {{ elif self.kernel_name_prefix in ["gemm_universal", "gemm_preshuffle"]: instance_code += f""" - // Kernel type - using GemmKernel = ck_tile::GemmKernel; + // Kernel type + using GemmKernel = ck_tile::GemmKernel; // Kernel arguments auto kargs = GemmKernel::MakeKernelArgs(args); @@ -802,8 +802,8 @@ struct SelectedKernel {{ ck_tile::tuple<>, // DsLayout CLayout, ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, // kM_ - TilePartitioner::NPerBlock, // kN_ + TileM, // kM_ + TileN, // kN_ WarpPerBlock_M, // MWave_ WarpPerBlock_N, // NWave_ WarpTileM, // MPerXdl_ diff --git a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py index 2225619fad..bea46de067 100644 --- a/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py +++ b/tile_engine/ops/gemm_streamk/gemm_streamk_instance_builder.py @@ -481,8 +481,6 @@ struct SelectedKernel {{ GemmUniversalTraits>; static float launch(const ck_tile::StreamKHostArgs& args, const ck_tile::stream_config& stream) {{ - const auto Run = [&](const auto memory_operation_) {{ - constexpr auto memory_operation = memory_operation_.value; constexpr auto scheduler = ck_tile::GemmPipelineScheduler::Intrawave; using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; // kNumWaveGroups_ using GemmEpilogue = ck_tile::CShuffleEpilogue; @@ -558,30 +555,12 @@ struct SelectedKernel {{ workspace_data.SetZero(); }} }}; - - + // Launch kernel - float ave_time = ck_tile::launch_kernel_time_mask( + return ck_tile::launch_kernel_time_mask( stream, reset_data_buffers, ck_tile::make_kernel(GemmKernel{{}}, grids, blocks, 0, kargs)); - return ave_time; - - // ck_tile::index_t num_wgs_per_tile = kargs.tile_partitioner.estimate_num_wgs_per_tile(); - // return std::make_tuple(ave_time, num_wgs_per_tile); - }}; - - - if constexpr(ck_tile::StreamKReductionStrategy::Atomic == reduction_strategy) - {{ - return Run(ck_tile::integral_constant{{}}); - }} - else // We are using ck_tile::StreamKReductionStrategy::Reduction - {{ - return Run(ck_tile::integral_constant{{}}); - }} }} }}; """
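Editor's closing note: every file in this patch applies the same mechanical change, deleting the `Run(integral_constant<memory_operation_enum, ...>)` wrapper and dropping the memory-operation template parameter from the epilogue problem, since the epilogue now branches on `k_batch` internally. For reference, a generic before-shape of the removed dispatch (names illustrative, not ck_tile API):

```cpp
// Generic form of the caller-side dispatch this patch removes: a runtime
// condition (k_batch, splitk, or reduction strategy) picked one of two
// compile-time instantiations via integral_constant tags.
#include <type_traits>

enum class mem_op { set, atomic_add };

template <typename F>
auto dispatch(bool split_k, F&& run)
{
    if(split_k)
        return run(std::integral_constant<mem_op, mem_op::atomic_add>{});
    return run(std::integral_constant<mem_op, mem_op::set>{});
}

// Usage: dispatch(k_batch > 1, [&](auto op) { /* instantiate with op.value */ return 0.f; });
```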