From 47ae4b0955582432a667b713865f13ec48a634ed Mon Sep 17 00:00:00 2001 From: John Shumway Date: Mon, 23 Jun 2025 07:24:36 -0700 Subject: [PATCH 1/4] Shard several of the most costly targets. (#2373) * Shard several of the most costly targets. Introduces a filter_tuple_by_modulo to break up tuples. Drops build time of target from 21 minutes to under 14 minutes with 64 build processes, or 11 minutes with 128 build processes. time ninja -j 64 device_grouped_conv3d_fwd_instance * fix clang format * Fix build errors in instantiation code. I wasn't sure how to test the header-only instantiation code on my initial commit. From Jenkins CI test results, I see that there is a test target that depends on these headers: ninja -j 128 test_grouped_convnd_fwd This allowed me to test the build locally. I found three mistakes I made, mostly related to early experiments I tried on the code. This was hard to find earlier because this PR is really too large. I also discovered that there are five 2D convolution targets that now dominate the compilation time. I will likely address those in a later PR, rather than adding even more changes to this PR. * Fix link errors from mismatched declarations. Our pattern for instantiating MIOpen templates uses duplicate declarations (instead of headers). This is fragile, and I didn't notice that my last commit had a bunch of link errors. I fixed these mistakes, and the bin/test_grouped_conv_fwd test target binary now links correctly. * Migrate the design to a code-generation approach. Use a CMake function with template files to generate the source files for instantiating the kernels and to generate the calling function. * Shard the longest 2D convolution builds Now that we have automated the shard instantiation, we can shard the 2D convolution targets that take the longest to build. The target test_grouped_conv2d_fwd now compiles in 15 minutes. 
* Use PROJECT_SOURCE_DIR for submodule compatibility I used CMAKE_SOURCE_DIR to refer to the top-level source directory in the ShardInstantiation.cmake file, but this can cause issues with git submodules. Instead, we should use PROJECT_SOURCE_DIR to ensure compatibility when this project is used as a submodule in another project. * Migrate the design to a code-generation approach. Use a CMake function with template files to generate the source files for instantiating the kernels and to generate the calling function. * Migrate the design to a code-generation approach. Use a CMake function with template files to generate the source files for instantiating the kernels and to generate the calling function. * Remove accidental copy of a file * Remove accidental copies of template files. --------- Co-authored-by: illsilin --- .gitignore | 3 + cmake/ShardInstantiation.cmake | 116 ++++++++++++++ cmake/call_shard.in | 15 ++ cmake/instantiate_shard.in | 9 ++ include/ck/utility/filter_tuple.hpp | 66 ++++++++ .../gpu/grouped_convolution_forward_xdl.inc | 3 +- .../gpu/grouped_conv2d_fwd/CMakeLists.txt | 51 ++++++- ...l_ngchw_gkcyx_ngkhw_bf16_comp_instance.in} | 38 ++--- ...wd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in} | 40 ++--- ...fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in} | 64 ++++---- ...gc_gkyxc_nhwgk_int8_mem_inter_instance.cpp | 66 -------- ...wgc_gkyxc_nhwgk_int8_mem_inter_instance.in | 80 ++++++++++ ...gc_gkyxc_nhwgk_int8_mem_intra_instance.cpp | 66 -------- ...wgc_gkyxc_nhwgk_int8_mem_intra_instance.in | 80 ++++++++++ .../gpu/grouped_conv3d_fwd/CMakeLists.txt | 109 +++++++++++-- ...dhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp | 111 -------------- ...ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in | 143 ++++++++++++++++++ ...ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp | 111 -------------- ..._ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in | 143 ++++++++++++++++++ ...gcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp | 54 ------- ...ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in | 65 
++++++++ ...ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp | 54 ------- ..._ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in | 63 ++++++++ ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp | 53 ------- ...xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in} | 53 ++++--- ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp | 53 ------- ..._xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in} | 53 ++++--- ...w_gkczyx_ngkdhw_bf16_mem_inter_instance.in | 64 ++++++++ ...w_gkczyx_ngkdhw_bf16_mem_intra_instance.in | 65 ++++++++ ...w_gkczyx_ngkdhw_f16_mem_inter_instance.in} | 69 +++++---- ...w_gkczyx_ngkdhw_f16_mem_intra_instance.in} | 75 ++++----- ...w_gkczyx_ngkdhw_f32_mem_inter_instance.in} | 69 +++++---- ...w_gkczyx_ngkdhw_f32_mem_intra_instance.in} | 69 +++++---- 33 files changed, 1346 insertions(+), 827 deletions(-) create mode 100644 cmake/ShardInstantiation.cmake create mode 100644 cmake/call_shard.in create mode 100644 cmake/instantiate_shard.in create mode 100644 include/ck/utility/filter_tuple.hpp rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/comp/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instance.in} (53%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in} (71%) rename library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/{device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp => device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in} (64%) delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in delete mode 100644 
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/{mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in} (64%) delete 
mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/{mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in} (64%) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in} (59%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in} (57%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in} (59%) rename library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/{device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp => device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in} (59%) diff --git a/.gitignore b/.gitignore index 599ef99e35..e4dd8f7513 100644 --- a/.gitignore +++ b/.gitignore @@ -68,3 +68,6 @@ build*/ # Python cache __pycache__/ + +.cache/ + diff --git a/cmake/ShardInstantiation.cmake b/cmake/ShardInstantiation.cmake new file mode 100644 index 0000000000..47a5d0c48c --- /dev/null +++ 
b/cmake/ShardInstantiation.cmake @@ -0,0 +1,116 @@ +# Function to generate templated instantiation functions and caller function. + +# In order to reduce build times, we split the instantiation of template functions into multiple files. +# Developers can use ck::util::filter_tuple_by_modulo_t to write the sharded instantiation functions, +# which are placed in the TEMPLATE_FILE (typically a .in file). + +# This CMake function generates the instantiation functions and a caller function that calls all the instantiation +# functions. The ck::util::filter_tuple_by_modulo_t alias allows us to generate an arbitrary number of +# shards (NUM_SHARDS). This function loops over the shards, generates an instantiation function for each shard, +# and generates a caller function that calls all the instantiation functions. + +# The explicit instantiation pattern requires the use of `extern template` to avoid implicit instantiation +# of the template functions in the caller function, and that code is automatically generated by this function. + +# In addition to the user-supplied template, this CMake function uses two generic templates: +# +# 1. `instantiate_shard.in`: This is the template for the instantiation functions. +# 2. `call_shard.in`: This is the template for the caller function that calls all the instantiation functions. + +# This function takes the following arguments: +# +# - INSTANCES_NAME: The name of the instances (the calling function will be named `add_${INSTANCES_NAME}`). +# - TEMPLATE_FILE: The path to the template file that contains the templated instantiation function definitions. +# - NUM_SHARDS: The number of shards to generate. +# - OUTPUT_DIR: The build directory where the generated source files will be placed. +# - SRC_LIST: The list of source files to which the generated source files will be added. 
+ + +function(generate_sharded_instantiations) + cmake_parse_arguments( + GEN_SHARDED + # No boolean arguments + "" + # Single-value arguments + "INSTANCES_NAME;TEMPLATE_FILE;NUM_SHARDS;OUTPUT_DIR;SRC_LIST" + # No multi-value arguments. + "" + ${ARGN} + ) + if (NOT GEN_SHARDED_INSTANCES_NAME) + message(FATAL_ERROR "INSTANCES_NAME is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_TEMPLATE_FILE) + message(FATAL_ERROR "TEMPLATE_FILE is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_NUM_SHARDS) + message(FATAL_ERROR "NUM_SHARDS is required for generate_sharded_instantiations") + endif() + if(NOT GEN_SHARDED_OUTPUT_DIR) + message(FATAL_ERROR "OUTPUT_DIR is required for generate_sharded_instantiations") + endif() + if (NOT GEN_SHARDED_SRC_LIST) + message(FATAL_ERROR "SRC_LIST is required for generate_sharded_instantiations") + endif() + + file(MAKE_DIRECTORY "${GEN_SHARDED_OUTPUT_DIR}") + + + set(GENERATED_SOURCE_FILES "") + set(EXTERN_TEMPLATE_STATEMENTS "") + set(CALL_STATEMENTS "") + message(STATUS "Generating sharded instantiations for target: ${GEN_SHARDED_INSTANCES_NAME}") + + set(INSTANCES "${GEN_SHARDED_INSTANCES_NAME}") + + # Generate the inc file with the template function definitions. + # This include file will hold the template function definitions and a using alias for all the shard + # instantiation functions. + configure_file( + "${GEN_SHARDED_TEMPLATE_FILE}" + "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.inc" + @ONLY + ) + + # Generate the sharded instantiation functions. + # This is where the build parallelization happens. + # Each of these source files will contain a single instantiation function for a shard, + # which will be called sequentially by the caller function. 
+ set(INC_DIR "${GEN_SHARDED_INC_DIR}") + math(EXPR LAST_SHARD_ID "${GEN_SHARDED_NUM_SHARDS} - 1") + foreach(SHARD_ID RANGE 0 ${LAST_SHARD_ID}) + set(NUM_SHARDS "${GEN_SHARDED_NUM_SHARDS}") + set(SHARD_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}_shard_${SHARD_ID}.cpp") + set(SHARD_FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/instantiate_shard.in") + configure_file( + "${SHARD_FUNCTION_TEMPLATE}" + "${SHARD_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${SHARD_FUNCTION_PATH}") + set(SHARDED_FUNCTION_NAME "add_${INSTANCES}_shard<${NUM_SHARDS}, ${SHARD_ID}>") + list(APPEND EXTERN_TEMPLATE_STATEMENTS "extern template void\n${SHARDED_FUNCTION_NAME}(\n ${INSTANCES}& instances)") + list(APPEND CALL_STATEMENTS " ${SHARDED_FUNCTION_NAME}(instances)") + endforeach() + + # Join the include statements, the extern template declarations, and the call statements each + # into a single string for variable substitution in the caller function. + string(REPLACE ";" ";\n" INCLUDE_STATEMENTS "${INCLUDE_STATEMENTS}") + string(REPLACE ";" ";\n" CALL_STATEMENTS "${CALL_STATEMENTS}") + string(REPLACE ";" ";\n" EXTERN_TEMPLATE_STATEMENTS "${EXTERN_TEMPLATE_STATEMENTS}") + + # Generate the caller function. + set(CALLER_FUNCTION_PATH "${GEN_SHARDED_OUTPUT_DIR}/${INSTANCES}.cpp") + set(FUNCTION_TEMPLATE "${PROJECT_SOURCE_DIR}/cmake/call_shard.in") + configure_file( + "${FUNCTION_TEMPLATE}" + "${CALLER_FUNCTION_PATH}" + @ONLY + ) + list(APPEND GENERATED_SOURCE_FILES "${CALLER_FUNCTION_PATH}") + + # Add the generated source files to the list of source files. + # This allows the generated source files to be included in the build. 
+ list(APPEND ${GEN_SHARDED_SRC_LIST} ${GENERATED_SOURCE_FILES}) + set(${GEN_SHARDED_SRC_LIST} "${${GEN_SHARDED_SRC_LIST}}" PARENT_SCOPE) +endfunction() \ No newline at end of file diff --git a/cmake/call_shard.in b/cmake/call_shard.in new file mode 100644 index 0000000000..daba79b055 --- /dev/null +++ b/cmake/call_shard.in @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { + +@EXTERN_TEMPLATE_STATEMENTS@; + +void add_@INSTANCES@( + @INSTANCES@& instances) { +@CALL_STATEMENTS@; +} + +} // namespace ck::tensor_operation::device::instance diff --git a/cmake/instantiate_shard.in b/cmake/instantiate_shard.in new file mode 100644 index 0000000000..dbc0af17a9 --- /dev/null +++ b/cmake/instantiate_shard.in @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "@INSTANCES@.inc" + +namespace ck::tensor_operation::device::instance { +template void add_@INSTANCES@_shard<@NUM_SHARDS@, @SHARD_ID@>( + @INSTANCES@& instances); +} // namespace ck::tensor_operation::device::instance diff --git a/include/ck/utility/filter_tuple.hpp b/include/ck/utility/filter_tuple.hpp new file mode 100644 index 0000000000..c2e378b879 --- /dev/null +++ b/include/ck/utility/filter_tuple.hpp @@ -0,0 +1,66 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include +#include +#include + +#include "ck/utility/functional.hpp" +#include "ck/utility/sequence.hpp" + +namespace ck::util { + +template +struct filter_tuple_by_modulo +{ + // Validate Stride and Offset. + static_assert(Stride > 0, "Stride must be positive."); + static_assert(Offset >= 0 && Offset < Stride, + "Offset must be non-negative and less than the stride."); + + // Generate filtered indices for this stride and offset. 
+ static constexpr int new_size = (std::tuple_size_v + Stride - Offset - 1) / Stride; + + template + static constexpr auto to_index(std::index_sequence) + { + return std::index_sequence<(Offset + Is * Stride)...>{}; + } + + using filtered_indices = decltype(to_index(std::make_index_sequence{})); + + // Helper struct to construct the new tuple type from the filtered indices. + template + struct make_filtered_tuple_type_impl; + + template + struct make_filtered_tuple_type_impl> + { + using type = std::tuple...>; + }; + + using type = typename make_filtered_tuple_type_impl::type; +}; + +// Filter a tuple with a stride and offset. +// +// Tuple is a std::tuple or equivalent +// Stride is a positive integer +// Offset is a non-negative integer smaller than Stride +// +// Evaluates to a smaller tuple type holding the elements of Tuple at indices Offset, Offset + Stride, ... +// +// Can be used to filter a tuple of types for sharded instantiations. +template +using filter_tuple_by_modulo_t = typename filter_tuple_by_modulo::type; + +// Example compile-time test: +// using OriginalTuple = +// std::tuple; +// using NewTuple_Every3rdFrom2nd = filter_tuple_by_modulo_t; +// static_assert(std::is_same_v>, +// "Test Case 1 Failed: Every 3rd from 2nd"); + +} // namespace ck::util diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc index b018737932..a3f2515099 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_forward_xdl.inc @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -688,7 +688,6 @@ void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( PassThrough, PassThrough, PassThrough>>>& instances); - void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instances( std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances_shard([[maybe_unused]] + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_comp_instances& instances) { add_device_operation_instances( instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in similarity index 71% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in index 4ca1b2b85e..88c84adfe2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include 
"ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances( +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances = std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances_shard( + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_bf16_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwdDefault>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<2, + ck::util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in index e3a12fd5f4..13fb583725 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances( +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances = std::vector>>& instances) + PassThrough>>>; + +// Compilation parameters for in[n, g, c, hi, wi] * wei[g, k, c, y, x] = out[n, g, k, ho, wo] +template +void add_device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances_shard( + device_grouped_conv2d_fwd_xdl_ngchw_gkcyx_ngkhw_f16_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwdDefault>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd1x1P0>{}); + ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<2, - NGCHW, - GKCYX, - Empty_Tuple, - NGKHW, - ConvFwd1x1S1P0>{}); + 
ck::util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp deleted file mode 100644 index f667481fa4..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Interwave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - 
NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Interwave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in new file mode 100644 index 0000000000..d8b35bda68 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instance.in @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances = + std::vector>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances_shard( + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_inter_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); + + 
add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Interwave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp deleted file mode 100644 index 2ff2c7f51f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.cpp +++ /dev/null @@ -1,66 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdDefault, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1P0, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwd1x1S1P0, - Intrawave>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_int8_mem_instances<2, - NHWGC, - GKYXC, - Empty_Tuple, - NHWGK, - ConvFwdOddC, - Intrawave>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in new file mode 100644 index 0000000000..125e16139d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_fwd/xdl/mem/device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instance.in @@ -0,0 +1,80 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances = + std::vector>>; + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +template +void add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances_shard( + device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_int8_mem_intra_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); + + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_int8_mem_instances<2, + NHWGC, + GKYXC, + Empty_Tuple, + NHWGK, + ConvFwdOddC, + Intrawave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt index f8efa5a7c1..1d9d75a104 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/CMakeLists.txt @@ -11,8 +11,6 @@ set(GROUPED_CONV3D_FWD xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_int8_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp - xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16x16_instance.cpp xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_16x16_instance.cpp @@ -32,23 +30,13 @@ set(GROUPED_CONV3D_FWD xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_mem_intra_instance.cpp xdl/mem/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp - xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp - 
xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp - xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp + xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_comp_instance.cpp +xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_comp_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_2x_instance.cpp xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_part2_instance.cpp @@ -71,6 +59,99 @@ set(GROUPED_CONV3D_FWD wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_f16_oddc_instance.cpp wmma/device_grouped_conv3d_fwd_wmma_ndhwgc_gkzyxc_ndhwgk_i8_oddc_instance.cpp ) +# Add generated files for sharded instantiations. 
+include(ShardInstantiation) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in + NUM_SHARDS 8 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl +) +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances + TEMPLATE_FILE xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in + NUM_SHARDS 8 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl +) + +set(GENERATED_DIR ${CMAKE_CURRENT_BINARY_DIR}/generated) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD 
+ OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances + TEMPLATE_FILE xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in + NUM_SHARDS 10 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/mem +) + +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) +generate_sharded_instantiations( + INSTANCES_NAME device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances + TEMPLATE_FILE xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in + NUM_SHARDS 12 + SRC_LIST GROUPED_CONV3D_FWD + OUTPUT_DIR ${GENERATED_DIR}/xdl/comp +) if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES) list(APPEND 
GROUPED_CONV3D_FWD diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp deleted file mode 100644 index a94f687ef8..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - 
ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in new file mode 100644 index 0000000000..9d0eba6a6c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instance.in @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances = + std::vector>>; + +template +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + + if(ck::get_device_name() != "gfx950") + { + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + 
ShardIndex>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp deleted file mode 100644 index 0c63345e7f..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.cpp +++ /dev/null @@ -1,111 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/host_utility/device_prop.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - - if(ck::get_device_name() != "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } - - if(ck::get_device_name() == "gfx950") - { - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwdDefault>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, - GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, - NDHWGC, 
- GKZYXC, - Empty_Tuple, - NDHWGK, - ConvFwd1x1S1P0>{}); - } -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in new file mode 100644 index 0000000000..ccabc2090a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instance.in @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances = + std::vector>>; + +template +void add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + + if(ck::get_device_name() != "gfx950") + { + 
add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_part2<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + } + + if(ck::get_device_name() == "gfx950") + { + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances_2x<3, + NDHWGC, + GKZYXC, + Empty_Tuple, + NDHWGK, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); + } +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp deleted file mode 100644 index 43241454a5..0000000000 --- 
a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_bf16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in new file mode 100644 index 0000000000..4c67e4912c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instance.in @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp deleted file mode 100644 index d02d9f6778..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.cpp +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances( - instances, - device_grouped_conv_fwd_xdl_f16_comp_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in new file mode 100644 index 0000000000..0fbefa3bbc --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/comp/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instance.in @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_comp_instances& instances) +{ + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_comp_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp deleted file mode 100644 index 060eebebc1..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in index f3eccc7dc8..c87783eed9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.in @@ -1,15 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances( +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instances& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // 
namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp deleted file mode 100644 index 85b088f416..0000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { -// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances( - std::vector>>& instances) -{ - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault>{}); - - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0>{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0>{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp 
b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in similarity index 64% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in index abea0bea81..ca6d571be1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.in @@ -1,15 +1,14 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances( +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_instances& instances) { - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + add_device_operation_instances( + 
instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + ConvFwdDefault>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); - add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, + ConvFwd1x1P0>, + Shards, + ShardIndex>{}); + + add_device_operation_instances( + instances, + util::filter_tuple_by_modulo_t{}); + ConvFwd1x1S1P0>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in new file mode 100644 index 0000000000..2586bc0f16 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.in @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in new file mode 100644 index 0000000000..7405f86a5f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instance.in @@ -0,0 +1,65 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, 
Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" + +namespace ck::tensor_operation::device::instance { + +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances = + std::vector>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_intra_instances& instances) +{ + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances(instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); + add_device_operation_instances( + instances, + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_bf16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); +} + +} // namespace ck::tensor_operation::device::instance + diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp rename to 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in index ba5d9fb1de..24d6b66976 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_inter_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, 
+ ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in similarity index 57% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in index fac3098341..91a2444241 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instance.in @@ -3,53 +3,60 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void 
add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_bf16_mem_inter_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f16_mem_intra_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_bf16_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f16_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in similarity index 59% rename from 
library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in index 5a2c4a0d5b..7571dff883 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_inter_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Interwave>{}); + 
ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Interwave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Interwave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Interwave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in similarity index 59% rename from library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp rename to library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in index 701b8eb4a4..38ed240fab 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_fwd/xdl/mem/device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instance.in @@ -3,13 +3,11 @@ #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_mem_instance.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/filter_tuple.hpp" -namespace ck { -namespace tensor_operation { 
-namespace device { -namespace instance { +namespace ck::tensor_operation::device::instance { -void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances( +using device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances = std::vector>>& instances) + PassThrough>>>; +template +void add_device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances_shard( + device_grouped_conv3d_fwd_xdl_ngcdhw_gkczyx_ngkdhw_f32_mem_intra_instances& instances) { add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwdDefault, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwdDefault, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1P0, + Intrawave>, + Shards, + ShardIndex>{}); add_device_operation_instances(instances, - device_grouped_conv_fwd_xdl_f32_mem_instances<3, - NGCDHW, - GKCZYX, - Empty_Tuple, - NGKDHW, - ConvFwd1x1S1P0, - Intrawave>{}); + ck::util::filter_tuple_by_modulo_t< + device_grouped_conv_fwd_xdl_f32_mem_instances<3, + NGCDHW, + GKCZYX, + Empty_Tuple, + NGKDHW, + ConvFwd1x1S1P0, + Intrawave>, + Shards, + ShardIndex>{}); } -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck +} // namespace ck::tensor_operation::device::instance From dbfe70e72a5f2f0317b715cd4c7f7fb662affbe5 Mon Sep 17 00:00:00 2001 From: Rostyslav Geyyer <46627076+geyyer@users.noreply.github.com> Date: Mon, 23 Jun 2025 09:31:46 -0500 Subject: [PATCH 2/4] Add accelerated stochastic rounding on gfx950 
(#2355) * Add native prand generation support for gfx950 * Update seed calculation --- include/ck/utility/amd_ck_fp8.hpp | 65 +++++++++++++--- include/ck/utility/mxf8_utils.hpp | 10 ++- include/ck/utility/type_convert.hpp | 114 ++++++++++++++++++---------- 3 files changed, 134 insertions(+), 55 deletions(-) diff --git a/include/ck/utility/amd_ck_fp8.hpp b/include/ck/utility/amd_ck_fp8.hpp index d079639c6a..cdc2a4fbda 100644 --- a/include/ck/utility/amd_ck_fp8.hpp +++ b/include/ck/utility/amd_ck_fp8.hpp @@ -5,6 +5,7 @@ #include "ck/ck.hpp" #include "ck/utility/enable_if.hpp" +#include "ck/utility/get_id.hpp" #include "ck/utility/random_gen.hpp" #include "ck/utility/functional.hpp" #include "ck/utility/type.hpp" @@ -1396,12 +1397,18 @@ __host__ __device__ static inline fp8_storage_t cvt_float_to_fp8(const float f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } return cast_to_f8_from_f32( f, rng); @@ -1416,12 +1423,18 @@ __host__ static inline fp8_storage_t cvt_float_to_fp8(const float f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f); + rng = prand_generator(reinterpret_cast(&f), f); #else rng = prand_generator(reinterpret_cast(&f), f); -#endif +#endif // 
#ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } if constexpr(interp == ck_fp8_interpretation_t::CK_E4M3_FNUZ) @@ -1487,12 +1500,18 @@ __device__ static inline fp8x2_storage_t cvt_float_to_fp8(const float2_t f) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - rng = prand_generator(reinterpret_cast(&f), f[0]); + rng = prand_generator(reinterpret_cast(&f), f[0]); #else rng = prand_generator(reinterpret_cast(&f), f[0]); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } return cast_to_f8_from_f32( f, rng); @@ -1532,12 +1551,18 @@ __host__ static inline fp8_storage_t cvt_half_t_to_fp8(const _Float16 x) uint32_t rng = 0; if constexpr(stochastic_rounding) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC rng = prand_generator(reinterpret_cast(&x), x); #else rng = prand_generator(reinterpret_cast(&x), x); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_f16(reinterpret_cast(&x), x[0]); #else rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_f16(reinterpret_cast(&x), static_cast(x)); #else rng = prand_generator(reinterpret_cast(&x), static_cast(x)); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_bf16(reinterpret_cast(&x), + rng = prand_generator(reinterpret_cast(&x), 
static_cast(x[0])); #else rng = prand_generator(reinterpret_cast(&x), static_cast(x[0])); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) } #if defined(__gfx950__) return cast_to_f8_from_bf16(reinterpret_cast(&f), f); + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); } return cast_to_f8_from_f32_scaled(f, rng, scale); } @@ -221,8 +222,9 @@ __host__ __device__ static inline fp8x2_storage_t cvt_float_to_fp8_scaled(const uint32_t rng = 0; if constexpr(stochastic_rounding) { - constexpr int seed = 1254739; - rng = prand_generator(reinterpret_cast(&f), f[0]); + // use HW clock for stochastic input multiply by incremented thread id + rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); } return cast_to_f8_from_f32_scaled(f, rng, scale); } diff --git a/include/ck/utility/type_convert.hpp b/include/ck/utility/type_convert.hpp index 5865f1dd78..2208a73860 100644 --- a/include/ck/utility/type_convert.hpp +++ b/include/ck/utility/type_convert.hpp @@ -5,6 +5,7 @@ #include "ck/utility/data_type.hpp" #include "ck/utility/f8_utils.hpp" +#include "ck/utility/get_id.hpp" #include "ck/utility/mxf4_utils.hpp" #include "ck/utility/mxf6_utils.hpp" #include "ck/utility/random_gen.hpp" @@ -234,12 +235,18 @@ __host__ __device__ constexpr Y f8_convert_sr(X x); template <> inline __host__ __device__ f8_fnuz_t f8_convert_sr(float x) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif 
+#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) #if defined(__gfx94__) union { @@ -296,12 +303,18 @@ inline __host__ __device__ f8_fnuz_t f8_convert_sr(half_t x) template <> inline __host__ __device__ bf8_fnuz_t f8_convert_sr(float x) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); +#else constexpr int seed = 1254739; #ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); + uint32_t rng = prand_generator(reinterpret_cast(&x), x); #else uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif +#endif // #ifndef CK_CODE_GEN_RTC +#endif // #if defined(__gfx950__) #if defined(__gfx94__) union { @@ -1446,13 +1459,10 @@ inline __host__ __device__ f4x32_t f4_convert_rne(float32_t x, float scale = 1.0 // convert fp32 to fp4 with stochastic rounding inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { uint32_t bitwise; @@ -1468,6 +1478,12 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, float scale = 1.0f) value.bitwise, float_values.float2_array, rng, scale, 0); return value.f4_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -1475,13 +1491,10 @@ inline __host__ __device__ f4_t f4_convert_sr(float x, 
float scale = 1.0f) // convert vector of 2 fp32 to vector of 2 fp4 with sr inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { uint32_t bitwise; @@ -1499,6 +1512,12 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) #endif // CK_WORKAROUND_FP32_TO_FP4_SR_CONVERSION return value.f4x2_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#endif union { uint32_t bitwise; @@ -1514,13 +1533,10 @@ inline __host__ __device__ f4x2_t f4_convert_sr(float2_t x, float scale = 1.0f) // convert vector of 32 fp32 to vector of 32 fp4 with sr inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { __uint128_t bitwise; @@ -1546,6 +1562,12 @@ inline __host__ __device__ f4x32_t f4_convert_sr(float32_t x, float scale = 1.0f return f4_values.f4x32_array; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x[0]); +#else + uint32_t rng = 
prand_generator(reinterpret_cast(&x), x[0]); +#endif union { __uint128_t bitwise; @@ -1776,13 +1798,10 @@ inline __host__ __device__ f6x32_t f6_convert_rne(float32_t x, float scale = 1.0 */ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { float32_t float_vector; @@ -1799,6 +1818,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) return out.f6_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -1815,6 +1840,12 @@ inline __host__ __device__ f6_t f6_convert_sr(float x, float scale = 1.0f) */ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); + return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); +#else constexpr int seed = 1254739; union { @@ -1828,9 +1859,6 @@ inline __host__ __device__ f6x32_t f6_convert_sr(float32_t x, float scale = 1.0f uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); #endif -#if defined(__gfx950__) - return __builtin_amdgcn_cvt_scalef32_sr_pk32_fp6_f32(x, rng, scale); -#else union { float32_t float_vector; @@ -2044,13 +2072,10 @@ inline __host__ __device__ bf6x32_t 
bf6_convert_rne(float32_t x, float scale = 1 */ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) { - constexpr int seed = 1254739; -#ifndef CK_CODE_GEN_RTC - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#else - uint32_t rng = prand_generator(reinterpret_cast(&x), x); -#endif #if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); union { float32_t float_vector; @@ -2067,6 +2092,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) return out.bf6_array[0]; #else + constexpr int seed = 1254739; +#ifndef CK_CODE_GEN_RTC + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#else + uint32_t rng = prand_generator(reinterpret_cast(&x), x); +#endif return utils::sat_convert_to_type_sr(x / scale, rng); #endif } @@ -2085,6 +2116,12 @@ inline __host__ __device__ bf6_t bf6_convert_sr(float x, float scale = 1.0f) */ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1.0f) { +#if defined(__gfx950__) + // use HW clock for stochastic input multiply by incremented thread id + uint32_t rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() * + (get_thread_global_1d_id() + 1)); + return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); +#else constexpr int seed = 1254739; union { @@ -2098,9 +2135,6 @@ inline __host__ __device__ bf6x32_t bf6_convert_sr(float32_t x, float scale = 1. 
uint32_t rng = prand_generator(reinterpret_cast(&x), float_values.float_array[0]); #endif -#if defined(__gfx950__) - return __builtin_amdgcn_cvt_scalef32_sr_pk32_bf6_f32(x, rng, scale); -#else union { float32_t float_vector; From b8212864cf569b347f26816bfd44a50cadd60e28 Mon Sep 17 00:00:00 2001 From: Yi DING Date: Tue, 24 Jun 2025 01:33:31 +0800 Subject: [PATCH 3/4] [CK_TILE] FMHA Support hdim_v to as a Multiple of 32 (#2114) * 160+192 * Add splitkv d160 * cleanup * fix * Add change log * Fix CHANGELOG * Use static_cast * Update ignored instance --------- Co-authored-by: asleepzzz --- CHANGELOG.md | 1 + example/ck_tile/01_fmha/README.md | 1 + .../ck_tile/01_fmha/codegen/ops/fmha_fwd.py | 45 +++++++-------- .../01_fmha/codegen/ops/fmha_fwd_splitkv.py | 5 +- example/ck_tile/01_fmha/fmha_fwd.cpp | 43 ++++----------- include/ck_tile/core/tensor/shuffle_tile.hpp | 7 ++- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 55 +++++++++++++++---- 7 files changed, 89 insertions(+), 68 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 368d1e502d..ab2076c0d8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added support for FP16 2:4 structured sparsity to universal GEMM. * Added support for Split K for grouped convolution backward data. * Added logit soft-capping support for fMHA forward kernels. +* Added support for hdim as a multiple of 32 for FMHA (fwd/fwd_splitkv) * Added benchmarking support for tile engine GEMM. * Added Ping-pong scheduler support for GEMM operation along the K dimension. * Added rotating buffer feature for CK_Tile GEMM. 
diff --git a/example/ck_tile/01_fmha/README.md b/example/ck_tile/01_fmha/README.md index 12414a20ed..72109a660b 100644 --- a/example/ck_tile/01_fmha/README.md +++ b/example/ck_tile/01_fmha/README.md @@ -71,6 +71,7 @@ args: -drop_seed seed for random number generator (default:1) -drop_offset offset for random number generator (default:0) -drop_prefs seed and offset values are present on GPU; 0 - host, 1 - device/GPU (default:0) + -num_splits number of splits for key/value. 0 to determine actual number by heuristic (default:1) -warmup number of iterations before benchmark the kernel (default:5) -repeat number of iterations to benchmark the kernel (default:20) ``` diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py index 7cbbdb9034..37a1b7329b 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd.py @@ -282,18 +282,19 @@ class FmhaFwdApiPool: # TODO: do we need to check duplication? 
if trait.dtype not in self.pool.keys(): self.pool[trait.dtype] = dict() - if trait.hdim not in self.pool[trait.dtype].keys(): - self.pool[trait.dtype][trait.hdim] = list() + hdim = trait.hdim, trait.bn1 + if hdim not in self.pool[trait.dtype].keys(): + self.pool[trait.dtype][hdim] = list() - self.pool[trait.dtype][trait.hdim].append(copy.copy(trait)) + self.pool[trait.dtype][hdim].append(copy.copy(trait)) @property def api(self) -> str: per_dtypes=str() for i, dtype in enumerate(self.pool.keys()): per_hdim_case=str() - for j, hdim in enumerate(self.pool[dtype].keys()): - traits=self.pool[dtype][hdim] + for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()): + traits=self.pool[dtype][(hdim, hdim_v)] inners=str() for k, trait in enumerate(traits): if_k = 'if' if k == 0 else 'else if' @@ -306,7 +307,7 @@ class FmhaFwdApiPool: F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype]) if_j = 'if' if j == 0 else 'else if' - per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners) + per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners) if_i = 'if' if i == 0 else 'else if' per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case) if not per_dtypes: @@ -435,18 +436,20 @@ class FmhaFwdKernel: def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: if dtype == 'fp16' or dtype == 'bf16': return { - '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - ### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 
32, 16, 32, 32, 16, -1), - '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (32, 32) : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (64, 64) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (96, 128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + (128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (160,160) : FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1), + (192,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), + ### (192,192) : FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1), + (256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': return { - '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), - '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (64,64 ) : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (128,128) : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), + (256,256) : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1), } else: return None @@ -454,7 +457,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]: # TODO: we don't support tuning 
yet, so pick up one value for vlayout/pipeline/pad # support this in future - def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]: + def get_pipelines(dtype, hdim, hdim_v) -> List[FmhaFwdPipeline]: # this function will populate a list possible pipelines # TODO: the order of List matters! the later in this list will be also be checked later # TODO: currently for qr pipeline, let 't' padding to appear later!! @@ -463,7 +466,7 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl pipelines = [] if dtype in ['fp16', 'bf16']: for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]): - if hdim == 256: + if hdim == 256 and hdim_v == 256: # if True: pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip)) @@ -507,15 +510,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl if d == None: continue #for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]): - for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()): - tile = d[hdim_str] - hdim = int(hdim_str) - for pipeline in get_pipelines(dtype, hdim): + for ((hdim, hdim_v), tile), mode in itertools.product(d.items(), MODE_MAP.keys()): + for pipeline in get_pipelines(dtype, hdim, hdim_v): if mode == "group": if pipeline.F_spad != 't' or pipeline.F_skpad != 't': # in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not continue - if hdim == 192 and tile.F_bn1 == 128: + if (hdim, hdim_v) == (192, 128) or hdim == 160: # NOTE: this is used to speedup deepseek prefill case, we don't gen training if pipeline.F_bias != 'no' or pipeline.F_dropout == 't': continue diff 
--git a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py index 3ae0e28be3..2d2d71555d 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_fwd_splitkv.py @@ -34,6 +34,7 @@ K0_MAX_SUBMAX_MAP = { 64 : 64, 96 : 128, 128: 128, + # 160: 160, 256: 256 } @@ -638,6 +639,7 @@ def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]: '64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), ### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), + ### '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -656,6 +658,7 @@ def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[d '64' : FmhaFwdSplitKVCombineTileSize(32, -1), ### '96' : FmhaFwdSplitKVCombineTileSize(32, -1), '128' : FmhaFwdSplitKVCombineTileSize(32, -1), + ### '160' : FmhaFwdSplitKVCombineTileSize(32, -1), '256' : FmhaFwdSplitKVCombineTileSize(32, -1), } elif dtype == 'fp8' or dtype == 'bf8': @@ -683,7 +686,7 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> if dtype in ['fp16', 'bf16']: for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]): # TODO: use async pipeline when compiler is more stable - if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128]: + if hdim == 256 or hdim in [32, 64, 128]: ### [32, 64, 96, 128, 160]: # if True: pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask)) pipelines.append(Pipeline('qr', 'col', 'f', 't', 
'f', 'f', logits, bias, 't', squant, pagedkv, mask)) diff --git a/example/ck_tile/01_fmha/fmha_fwd.cpp b/example/ck_tile/01_fmha/fmha_fwd.cpp index 8958c0c96e..972653c218 100755 --- a/example/ck_tile/01_fmha/fmha_fwd.cpp +++ b/example/ck_tile/01_fmha/fmha_fwd.cpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #include "fmha_fwd.hpp" #include "ck_tile/host.hpp" @@ -178,50 +178,30 @@ auto get_elimit(std::string init_method) } } -int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int max_splits) +int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int max_splits) { // If we have enough to almost fill the SMs, then just use 1 split if(batch_nhead_mblocks >= 0.8f * num_SMs) { return 1; } - max_splits = std::min({max_splits, num_SMs, num_n_blocks}); + max_splits = std::min({max_splits, num_SMs}); float max_efficiency = 0.f; std::vector efficiency; efficiency.reserve(max_splits); - auto ceildiv = [](int a, int b) { return (a + b - 1) / b; }; - // Some splits are not eligible. For example, if we have 64 blocks and choose 11 splits, - // we'll have 6 * 10 + 4 blocks. If we choose 12 splits, we'll have 6 * 11 + (-2) blocks - // (i.e. it's 11 splits anyway). - // So we check if the number of blocks per split is the same as the previous num_splits. 
- auto is_split_eligible = [&ceildiv, &num_n_blocks](int num_splits) { - return num_splits == 1 || - ceildiv(num_n_blocks, num_splits) != ceildiv(num_n_blocks, num_splits - 1); - }; for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) + float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; + float eff = n_waves / ceil(n_waves); + // printf("num_splits = %d, eff = %f\n", num_splits, eff); + if(eff > max_efficiency) { - efficiency.push_back(0.f); - } - else - { - float n_waves = float(batch_nhead_mblocks * num_splits) / num_SMs; - float eff = n_waves / ceil(n_waves); - // printf("num_splits = %d, eff = %f\n", num_splits, eff); - if(eff > max_efficiency) - { - max_efficiency = eff; - } - efficiency.push_back(eff); + max_efficiency = eff; } + efficiency.push_back(eff); } for(int num_splits = 1; num_splits <= max_splits; num_splits++) { - if(!is_split_eligible(num_splits)) - { - continue; - } if(efficiency[num_splits - 1] >= 0.85 * max_efficiency) { // printf("num_splits chosen = %d\n", num_splits); @@ -234,6 +214,7 @@ int num_splits_heuristic(int batch_nhead_mblocks, int num_SMs, int num_n_blocks, int override_num_splits_if_necessary( int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) { + (void)hdim_v; int device; auto status = hipGetDevice(&device); if(status != hipSuccess) @@ -250,15 +231,13 @@ int override_num_splits_if_necessary( // tile size should match the generate.py const int kM0 = 64; - const int kN1 = hdim_v; const int num_m_blocks = ck_tile::integer_divide_ceil(max_seqlen_q, kM0); - const int num_n_blocks = ck_tile::integer_divide_ceil(hdim_v, kN1); if(num_splits < 1 && p_drop == 0.0f) { return num_splits_heuristic( - batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); + batch * nhead * num_m_blocks, props.multiProcessorCount * 2, 128); } return num_splits; diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp 
b/include/ck_tile/core/tensor/shuffle_tile.hpp index 55e3274cde..84c2b7d2fa 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -129,7 +129,10 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT // set output vectors static_for<0, num_vec_out, 1>{}([&](auto i) { constexpr auto idx_y_out_tmp = generate_array( - [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, + [&](auto ii) { + return ii == y_dim_vec_in ? static_cast(idx_y_start[ii]) + i + : static_cast(idx_y_start[ii]); + }, number{}); constexpr auto idx_y_out = diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index 30d07a4754..0b8e5836cd 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -787,12 +787,29 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); - static_assert(kKPack % K3 == 0); + constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t K3 = total_pixels / N1; constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - if constexpr(get_warp_size() % (K2 * N0) == 0) + if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible + { + constexpr index_t kNPack = 32; + static_assert(kNPerBlock % kNPack == 0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t N2 = 2; + constexpr index_t N1_m = kNPack / N2; + constexpr index_t N0_m = kNPerBlock / kNPack; + constexpr index_t K1 = get_warp_size() / N1_m; + constexpr index_t K2_m = kKPerBlock / K1; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<1>, + tuple, sequence>, + tuple, sequence<2, 1>>, // K0, K1 N0 + tuple, sequence<1, 1>>, + sequence<1, 2, 1>, // N0 K2 N2 + sequence<0, 2, 2>>{}); + } + else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = kBlockSize / get_warp_size(); @@ -860,12 +877,28 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy(); constexpr index_t N0 = kNPerBlock / N1; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; - static_assert(total_pixels % N1 == 0); // TODO: this is not always true? 
- constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); - static_assert(kKPack % K3 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemKPackV(); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave - if constexpr(get_warp_size() % (K2 * N0) == 0) + if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible + { + constexpr index_t kNPack = 32; + static_assert(kNPerBlock % kNPack == 0); + constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t N2 = 2; + constexpr index_t N1_m = kNPack / N2; + constexpr index_t N0_m = kNPerBlock / kNPack; + constexpr index_t K1 = get_warp_size() / N1_m; + constexpr index_t K2_m = kKPerBlock / K1; + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1>>, // K0, K1 N0 + tuple, sequence<1, 1>>, + sequence<1, 1, 2>, // N0 K2 <-> N2 + sequence<0, 2, 2>>{}); + } + else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0) { constexpr index_t K1 = get_warp_size() / (K2 * N0); constexpr index_t K0 = kBlockSize / get_warp_size(); From bb571a033019fd5a8ba6de31119395c3621a4235 Mon Sep 17 00:00:00 2001 From: lalala-sh Date: Tue, 24 Jun 2025 14:51:29 +0800 Subject: [PATCH 4/4] fix moe i4 bug from aiter (#2339) --- ...dlops_b_preshuffle_gufusion_dequant_v1.hpp | 50 ++++++++----------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp index 4f7b8e768c..29750b8baa 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_dequant_v1.hpp @@ -122,7 +122,6 @@ struct 
BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< using Base::B_K1; using Base::I0; using Base::I1; - using Base::KGroup; using Base::KRepeat; using Base::xdlops_gemm; using typename Base::HotLoopInstList; @@ -154,9 +153,9 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{}); constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{}); constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{}); - constexpr index_t K2 = KPack / KGroup; + constexpr index_t K2 = KPack; constexpr index_t K1 = 64 / NPerXDL; - constexpr index_t K0 = KRepeat * KGroup; + constexpr index_t K0 = KRepeat; return transform_tensor_descriptor( TileDesc_M0_M1_M2_K{}, @@ -291,14 +290,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< block_sync_lds(); static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -391,15 +388,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run( - a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + 
a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -483,14 +477,12 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, KRepeat, 1>{}([&](auto k0) { - static_for<0, KGroup, 1>{}([&](auto kg0) { - a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, - make_tuple(m0, I0, I0, Number{}, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(m0, I0, I0, k0, I0, Number{}), - a_thread_buf); - }); + a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2, + make_tuple(m0, I0, I0, k0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(m0, I0, I0, k0, I0, I0), + a_thread_buf); }); }); // B VGPR->VGPR dequant @@ -596,7 +588,7 @@ struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_bdequant_v1< ComputeDataType, decltype(a_block_desc_m0_m1_m2_k0_k1_k2), decltype(a_thread_desc_), - Sequence<1, 1, 1, 1, 1, KPack / KGroup>, + Sequence<1, 1, 1, 1, 1, KPack>, Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1,