From b6e74be1aa38396609bca91cba5f9e5f8665e4b0 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 5 Nov 2024 08:53:10 -0800 Subject: [PATCH 01/24] Make sure cmake can handle the xnack+/xnack- targets. (#1633) * make sure cmake can handle xnack targets * dont build xdl instances for gfx906:xnack- * dont build xdl tests for gfx906:xnack- --- example/CMakeLists.txt | 8 ++++---- .../src/tensor_operation_instance/gpu/CMakeLists.txt | 10 +++++----- test/CMakeLists.txt | 12 ++++++------ 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index ad3f7c787f..22af7b2d5f 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -85,9 +85,9 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) @@ -169,9 +169,9 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) elseif(FILE_NAME MATCHES "_wmma") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) endif() set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP) add_executable(${EXAMPLE_NAME} ${FILE_NAME}) diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index 6756c33514..c8bbd6eb09 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -88,19 +88,19 @@ function(add_instance_library INSTANCE_NAME) foreach(source IN LISTS ARGN) set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) if(source MATCHES "_xdl") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) elseif(source MATCHES "_wmma") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) elseif(source MATCHES "mha") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) endif() #only build the fp8 gemm instances for gfx908/90a if the build argument is set if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) endif() if(source MATCHES "gemm_multiply_multiply_f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) endif() endif() set(offload_targets) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b12ced5244..a81c5a96ba 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -64,11 +64,11 @@ function(add_test_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") - list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) elseif(ARGN MATCHES "_wmma") - list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) elseif(ARGN MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -141,11 +141,11 @@ function(add_gtest_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) elseif(ARGN MATCHES "_wmma") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) elseif(ARGN MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) From d0e3a70a2e3ebb8f979c82309e3e58b5c23fe865 Mon Sep 17 00:00:00 2001 From: darren-amd Date: Tue, 5 Nov 2024 12:59:08 -0500 Subject: [PATCH 02/24] Statically Cast Pointer Offset (#1631) * explicit cast ptr offset * formating change --- ...nv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp | 12 +++++----- ...conv_bwd_weight_two_stage_xdl_cshuffle.hpp | 24 +++++++++---------- ..._conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp | 12 +++++----- ...gridwise_gemm_multiple_d_wmma_cshuffle.hpp | 24 +++++++++---------- .../gpu/grid/gridwise_tensor_rearrange.hpp | 8 +++---- 5 files changed, 40 insertions(+), 40 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp index 5e9da459c0..b544c925e1 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp @@ -93,12 +93,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp index d3c0f84b9f..c1f58ccda5 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp @@ -60,12 +60,12 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -117,12 +117,12 @@ __global__ void const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * NumGroupsToMerge); const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.y * num_k_per_block); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp index 65b7b6cb7a..3e14f66a09 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp @@ -98,12 +98,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t c_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t c_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp index b3b057c80a..de6c9c1601 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp @@ -60,12 +60,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); @@ -155,12 +155,12 @@ __global__ void __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count); const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); - const long_index_t a_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const long_index_t b_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); - const long_index_t e_batch_offset = - amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + const long_index_t a_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const long_index_t b_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx))); + const long_index_t e_batch_offset = amd_wave_read_first_lane( + static_cast(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx))); const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp index 1740749907..ddf0b4a58d 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_tensor_rearrange.hpp @@ -121,10 +121,10 @@ struct GridwiseTensorRearrange __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch); // Global Memory - const index_t a_batch_offset = - __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); - const index_t c_batch_offset = - __builtin_amdgcn_readfirstlane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)); + const index_t a_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx))); + const index_t c_batch_offset = __builtin_amdgcn_readfirstlane( + static_cast(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx))); const auto in_global_buf = make_dynamic_buffer( p_in_global + a_batch_offset, in_grid_desc.GetElementSpaceSize()); From 54440cf562b31eea6a158057fd8c41e9db1b4cc8 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 5 Nov 2024 13:56:20 -0800 Subject: [PATCH 03/24] remove gfx940;gfx941 from default target lists (#1640) --- CMakeLists.txt | 8 ++++---- Jenkinsfile | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 74628597af..bd2f606835 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -145,20 +145,20 @@ message("hip_version_flat=${hip_VERSION_FLAT}") message("checking which targets are supported") #In order to build just the CK library (without tests and examples) for all supported GPU targets -#use -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" +#use -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201" #the GPU_TARGETS flag will be reset in this case in order to avoid conflicts. # #In order to build CK along with all tests and examples it should be OK to set GPU_TARGETS to just 1 or 2 similar architectures. if(NOT ENABLE_ASAN_PACKAGING) if(NOT WIN32 AND ${hip_VERSION_FLAT} LESS 600300000) # WORKAROUND: compiler does not yet fully support gfx12 targets, need to fix version above - set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102") else() - set(CK_GPU_TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") + set(CK_GPU_TARGETS "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201") endif() else() #build CK only for xnack-supported targets when using ASAN - set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx940:xnack+;gfx941:xnack+;gfx942:xnack+") + set(CK_GPU_TARGETS "gfx908:xnack+;gfx90a:xnack+;gfx942:xnack+") endif() #if user set GPU_ARCHS on the cmake command line, overwrite default target list with user's list diff --git a/Jenkinsfile b/Jenkinsfile index 48b4c805cd..b79b2045b0 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -1101,11 +1101,11 @@ pipeline { agent{ label rocmnode("gfx90a") } environment{ setup_args = """ -DCMAKE_INSTALL_PREFIX=../install \ - -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ + -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ -DCMAKE_CXX_FLAGS=" -O3 " """ execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \ cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \ - -DGPU_TARGETS="gfx908;gfx90a;gfx940;gfx941;gfx942" \ + -DGPU_TARGETS="gfx908;gfx90a;gfx942" \ -DCMAKE_CXX_COMPILER="${build_compiler()}" \ -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """ } @@ -1165,7 +1165,7 @@ pipeline { execute_args = """ cmake -D CMAKE_PREFIX_PATH=/opt/rocm \ -D CMAKE_CXX_COMPILER="${build_compiler()}" \ -D CMAKE_BUILD_TYPE=Release \ - -D GPU_ARCHS="gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ + -D GPU_ARCHS="gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1102" \ -D CMAKE_CXX_FLAGS=" -O3 " .. && make -j64 """ } steps{ From 365f39aed0d5335b6e39d5049231558128cfedd9 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Tue, 5 Nov 2024 14:58:29 -0700 Subject: [PATCH 04/24] Prevent instantiation of undefined FP8 operators. (#1639) --- .../elementwise_scale_permute_amax_2D_fp16_fp8.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp index 7ac3c4e239..9431a8cde4 100644 --- a/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp +++ b/example/44_elementwise_permute/elementwise_scale_permute_amax_2D_fp16_fp8.cpp @@ -68,7 +68,7 @@ using DeviceElementwisePermuteInstance = ck::tensor_operation::device::DeviceEle using DeviceReduceInstance = ck::tensor_operation::device::DeviceReduceMultiBlock& input, host_output_scaled_casted_transposed(m, k) = y1; const OutputDataType y_fabs = ck::type_convert(ck::math::abs(ck::type_convert(y0))); - host_output_amax(0) = ck::math::max(y_fabs, host_output_amax(0)); + host_output_amax(0) = ck::type_convert(ck::math::max( + ck::type_convert(y_fabs), ck::type_convert(host_output_amax(0)))); } } } From dcafb1de15a8fd1de3496f19fd806ac9cb185012 Mon Sep 17 00:00:00 2001 From: aledudek Date: Wed, 6 Nov 2024 10:44:58 +0100 Subject: [PATCH 05/24] Generic threshold calculation after merge fixes (#1618) * Generic threshold calculation add passing num of accums * Generic threshold - after merge fixes * Fix cmakelists --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com> --- .../include/ck/library/utility/check_err.hpp | 8 ++++---- .../profiler/profile_pool3d_fwd_impl.hpp | 18 ++++++++++++++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/library/include/ck/library/utility/check_err.hpp b/library/include/ck/library/utility/check_err.hpp index 73ac2a189f..88741c3b96 100644 --- a/library/include/ck/library/utility/check_err.hpp +++ b/library/include/ck/library/utility/check_err.hpp @@ -24,7 +24,7 @@ namespace ck { namespace utils { template -double get_relative_threshold(const int numberOfAccumulations = 1) +double get_relative_threshold(const int number_of_accumulations = 1) { using F8 = ck::f8_t; using F16 = ck::half_t; @@ -79,13 +79,13 @@ double get_relative_threshold(const int numberOfAccumulations = 1) } else { - acc_error = std::pow(2, -NumericUtils::mant) * 0.5 * numberOfAccumulations; + acc_error = std::pow(2, -NumericUtils::mant) * 0.5 * number_of_accumulations; } return std::max(acc_error, midway_error); } template -double get_absolute_threshold(const double max_possible_num, const int numberOfAccumulations = 1) +double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1) { using F8 = ck::f8_t; using F16 = ck::half_t; @@ -142,7 +142,7 @@ double get_absolute_threshold(const double max_possible_num, const int numberOfA else { acc_error = - std::pow(2, expo - NumericUtils::mant) * 0.5 * numberOfAccumulations; + std::pow(2, expo - NumericUtils::mant) * 0.5 * number_of_accumulations; } return std::max(acc_error, midway_error); } diff --git a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp index a0890028ac..cbdacad53b 100644 --- a/profiler/include/profiler/profile_pool3d_fwd_impl.hpp +++ b/profiler/include/profiler/profile_pool3d_fwd_impl.hpp @@ -240,6 +240,19 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& { out_device_buf.FromDevice(out_n_c_do_ho_wo_device.mData.data()); + auto number_of_accumulations = 1; + static_assert( + ReduceOpId == ck::ReduceTensorOp::AVG || ReduceOpId == ck::ReduceTensorOp::MAX, + "Warning: Unhandled ReduceOpId for setting up the number of accumulations!"); + + if constexpr(ReduceOpId == ck::ReduceTensorOp::AVG) + { + for(size_t i = 0; i < kernel_params.window_spatial_lengths.size(); ++i) + { + number_of_accumulations *= kernel_params.window_spatial_lengths.at(i); + } + } + auto absolute_error_threshold = 1.0; switch(in_params.init_method) { @@ -250,9 +263,10 @@ bool profile_pool3d_fwd_impl(PoolFwdInputParams& in_params, PoolFwdKernelParams& absolute_error_threshold = ck::utils::get_absolute_threshold( - absolute_error_threshold); + absolute_error_threshold, number_of_accumulations); auto relative_error_threshold = - ck::utils::get_relative_threshold(); + ck::utils::get_relative_threshold( + number_of_accumulations); bool pass = ck::utils::check_err(out_n_c_do_ho_wo_device.mData, out_n_c_do_ho_wo_host.mData, From 3599418aa8f6b19e94c09160a086030ed50c7184 Mon Sep 17 00:00:00 2001 From: rocking Date: Thu, 7 Nov 2024 03:32:44 +0800 Subject: [PATCH 06/24] Fix F16 type (#1583) --- profiler/src/profile_layernorm_fwd.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/profiler/src/profile_layernorm_fwd.cpp b/profiler/src/profile_layernorm_fwd.cpp index a261bd7418..7031b36531 100644 --- a/profiler/src/profile_layernorm_fwd.cpp +++ b/profiler/src/profile_layernorm_fwd.cpp @@ -85,7 +85,7 @@ int profile_layernorm(int argc, char* argv[]) if(data_type == ck::DataTypeEnum::Half) { - ck::profiler::profile_layernorm_impl( + ck::profiler::profile_layernorm_impl( do_verification, init_method, do_log, time_kernel, length); } else if(data_type == ck::DataTypeEnum::Float) From 75c5bfa3642cb368acae5c7824aa7d6c506f5dae Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 7 Nov 2024 14:14:42 -0800 Subject: [PATCH 07/24] enable compilation for generic navi targets (#1645) --- example/CMakeLists.txt | 4 ++-- include/ck/ck.hpp | 8 +++++--- include/ck/utility/amd_wmma.hpp | 5 +++-- include/ck_tile/core/config.hpp | 8 +++++--- library/src/tensor_operation_instance/gpu/CMakeLists.txt | 8 ++++---- test/CMakeLists.txt | 8 ++++---- 6 files changed, 23 insertions(+), 18 deletions(-) diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt index 22af7b2d5f..ea739c7071 100644 --- a/example/CMakeLists.txt +++ b/example/CMakeLists.txt @@ -85,7 +85,7 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) endif() @@ -169,7 +169,7 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME) #only continue if there are some source files left on the list if(FILE_NAME) if(FILE_NAME MATCHES "_xdl") - list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(FILE_NAME MATCHES "_wmma") list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) endif() diff --git a/include/ck/ck.hpp b/include/ck/ck.hpp index 5f74d51a65..999eb0229c 100644 --- a/include/ck/ck.hpp +++ b/include/ck/ck.hpp @@ -63,13 +63,15 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING) #define __gfx101__ #endif #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ - defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \ + defined(__gfx10_3_generic__) #define __gfx103__ #endif -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx11_generic__) #define __gfx11__ #endif -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) #define __gfx12__ #endif diff --git a/include/ck/utility/amd_wmma.hpp b/include/ck/utility/amd_wmma.hpp index 322a0f94bb..d04513f3e8 100644 --- a/include/ck/utility/amd_wmma.hpp +++ b/include/ck/utility/amd_wmma.hpp @@ -9,7 +9,8 @@ // TODO: Add arch limitation namespace ck { -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx11_generic__) #define __gfx11__ #endif /********************************WAVE32 MODE***********************************************/ @@ -260,7 +261,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp> // gfx12 /********************************WAVE32 MODE***********************************************/ -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) #define __gfx12__ #endif diff --git a/include/ck_tile/core/config.hpp b/include/ck_tile/core/config.hpp index 4be50b8656..604c9551ff 100644 --- a/include/ck_tile/core/config.hpp +++ b/include/ck_tile/core/config.hpp @@ -11,13 +11,15 @@ #define __gfx94__ #endif #if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || \ - defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) + defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || \ + defined(__gfx10_3_generic__) #define __gfx103__ #endif -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) +#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || \ + defined(__gfx1103__) || defined(__gfx11_generic__) #define __gfx11__ #endif -#if defined(__gfx1200__) || defined(__gfx1201__) +#if defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx12_generic__) #define __gfx12__ #endif diff --git a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index c8bbd6eb09..80f0fc306b 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -88,19 +88,19 @@ function(add_instance_library INSTANCE_NAME) foreach(source IN LISTS ARGN) set(INST_TARGETS ${SUPPORTED_GPU_TARGETS}) if(source MATCHES "_xdl") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(source MATCHES "_wmma") list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) elseif(source MATCHES "mha") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() #only build the fp8 gemm instances for gfx908/90a if the build argument is set if(NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH) if(source MATCHES "gemm_xdl_universal" AND source MATCHES "f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() if(source MATCHES "gemm_multiply_multiply_f8") - list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack- gfx908:xnack+ gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() endif() set(offload_targets) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index a81c5a96ba..498a20dc55 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -64,11 +64,11 @@ function(add_test_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) elseif(ARGN MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) @@ -141,11 +141,11 @@ function(add_gtest_executable TEST_NAME) #only continue if there are some source files left on the list if(ARGN) if(ARGN MATCHES "_xdl") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) elseif(ARGN MATCHES "_wmma") list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx908:xnack+ gfx908:xnack- gfx90a:xnack+ gfx90a:xnack- gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030) elseif(ARGN MATCHES "_smfmac") - list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201) + list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx906:xnack- gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a gfx1200 gfx1201 gfx10.3-generic gfx11-generic gfx12-generic) endif() set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP) add_executable(${TEST_NAME} ${ARGN}) From 686a58a912f6884a9b66841cf04b4b81ba35aa7f Mon Sep 17 00:00:00 2001 From: dummycoderfe Date: Fri, 8 Nov 2024 12:28:23 +0800 Subject: [PATCH 08/24] [Ck tile] layernorm2d fwd optimize (#1637) * optimze small N case using vec io and using rcp div * [Ck_tile] layernorm, add param to control fastdiv; change generate codes and test pass * [Ck_tile] fix blockSize compute in Generic2dBlockShape * [Ck_tile]fix kfastfdiv template style * [Ck_tile] layernorm, fix stype in review --------- Co-authored-by: dummycoderfe --- example/ck_tile/02_layernorm2d/generate.py | 105 ++++++++++-------- .../ops/common/generic_2d_block_shape.hpp | 12 +- ...ayernorm2d_fwd_pipeline_default_policy.hpp | 12 +- .../layernorm2d_fwd_pipeline_one_pass.hpp | 11 +- .../pipeline/layernorm2d_fwd_traits.hpp | 2 + .../ops/welford/block/block_welford.hpp | 34 ++++-- .../welford/block/block_welford_problem.hpp | 9 +- .../ops/welford/thread/thread_welford.hpp | 43 +++++-- 8 files changed, 144 insertions(+), 84 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index 09aa6b65f8..ca9e432a4f 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -57,6 +57,7 @@ template @@ -118,6 +119,7 @@ struct layernorm2d_fwd_traits_ static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kTwoPass = kTwoPass_; static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_; static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_; @@ -134,6 +136,7 @@ template @@ -148,6 +151,7 @@ using traits_ = layernorm2d_fwd_traits_; @@ -179,6 +183,7 @@ float layernorm2d_fwd_(const S& s, A a) using PipelineTraits = ck_tile::Layernorm2dFwdTraits(Traits_::kFusedAdd), static_cast(Traits_::kFusedQuant)>; @@ -269,7 +274,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, #include "layernorm2d_fwd_api_common.hpp" // clang-format off -// prec_i prec_o prec_sy rm rn tm tn vn pd mv 2p add sweep +// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf 2p add sweep {F_instance_def} // clang-format on @@ -356,6 +361,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, F_Vector_N : int F_kPadN : bool F_kSaveMeanInvStd_ : bool + F_kFastFDiv_ : bool F_kTwoPass_ : bool F_kFusedAdd : int F_kFusedQuant : int @@ -363,7 +369,7 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, @property def trait_name(self) ->str: t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_XScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}' - t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}' + t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}' t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}' return t_ @@ -483,52 +489,55 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t, fused_add_list = [0, 1] fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant - # rm rn tm tn vn pd mv 2p add sweep - h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, False, 0, 0)], - '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, False, 0, 0)], - '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, False, 0, 0)], - '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, False, 0, 0)], - '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, False, 0, 0)], - '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, False, 0, 0)], - '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, False, 0, 0)], - '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, False, 0, 0)], - '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, False, 0, 0)], - '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, False, 0, 0)], - '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, False, 0, 0)], - '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, False, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, False, 0, 0)], - 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, 0, 0), - h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, 0, 0)]} + # rm rn tm tn vn pd mv fdiv 2p add sweep + h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, False, 0, 0)], + '128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, False, 0, 0)], + '256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, False, 0, 0)], + '512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, False, 0, 0)], + '768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, False, 0, 0)], + '1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, False, 0, 0)], + '1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, False, 0, 0)], + '2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, False, 0, 0)], + '3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, False, 0, 0)], + '4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, False, 0, 0)], + '6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, False, 0, 0)], + '8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, False, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, False, 0, 0)], + 'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, 0, 0), + h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, 0, 0)]} total_blob = list() for hs_key in h_trait_dict: hs = h_trait_dict[hs_key] diff --git a/include/ck_tile/ops/common/generic_2d_block_shape.hpp b/include/ck_tile/ops/common/generic_2d_block_shape.hpp index 64ad20c3be..c0bfd93198 100644 --- a/include/ck_tile/ops/common/generic_2d_block_shape.hpp +++ b/include/ck_tile/ops/common/generic_2d_block_shape.hpp @@ -38,9 +38,7 @@ namespace ck_tile { template typename WarpPerBlock_, // num warps along seq typename WarpTile_, // warp size, seq - typename Vector_, // contiguous pixels(vector size) along seq - index_t BlockSize_ = - warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> + typename Vector_> // contiguous pixels(vector size) along seq)> struct Generic2dBlockShape { // block size @@ -68,10 +66,12 @@ struct Generic2dBlockShape static_assert(Warp_M % Vector_M == 0); static_assert(Warp_N % Vector_N == 0); // num of threads along seq, within each warp - static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + static constexpr index_t ThreadPerBlock_M = Block_M / Repeat_M / Vector_M; + static constexpr index_t ThreadPerBlock_N = Block_N / Repeat_N / Vector_N; - static constexpr index_t BlockSize = BlockSize_; + static constexpr index_t BlockSize = ThreadPerBlock_M * ThreadPerBlock_N; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp index 1de230c144..724f6261d5 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp @@ -47,7 +47,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy { using P_ = BlockWelfordProblem; + typename Problem::BlockShape, + Problem::Traits::kFastFDiv>; return BlockWelford{}; } @@ -57,7 +58,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy { using P_ = BlockWelfordProblem; + typename Problem::BlockShape, + Problem::Traits::kFastFDiv>; return BlockWelfordSync{}; } @@ -67,7 +69,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy { using P_ = BlockWelfordProblem; + typename Problem::BlockShape, + Problem::Traits::kFastFDiv>; return BlockWelfordCrossWarpSync{}; } @@ -79,7 +82,8 @@ struct Layernorm2dFwdPipelineDefaultPolicy { using P_ = BlockWelfordProblem; + typename Problem::BlockShape, + Problem::Traits::kFastFDiv>; using block_welford = BlockWelford; using x_block_tile = diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 83cdab428e..4b83ed4fbf 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineOnePass static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; @@ -125,7 +126,15 @@ struct Layernorm2dFwdPipelineOnePass // compute inv-std auto inv_std = tile_elementwise_in( [&](const auto& v_) { - return type_convert(1.0f) / (sqrt(v_ + epsilon)); + if(kFastFDiv && std::is_same_v) + { + return type_convert(1.0f) * + __builtin_amdgcn_rcpf(sqrt(v_ + epsilon)); + } + else + { + return type_convert(1.0f) / sqrt(v_ + epsilon); + } }, var); diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp index ed9e18be30..e8c22f8ab5 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp @@ -39,6 +39,7 @@ template<> struct Layernorm2dFusedQuantEnumName @@ -46,6 +47,7 @@ struct Layernorm2dFwdTraits { static constexpr bool kPadN = kPadN_; static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_; + static constexpr bool kFastFDiv = kFastFDiv_; static constexpr bool kTwoPass = kTwoPass_; static constexpr Layernorm2dFusedAddEnum kFusedAdd = kFusedAdd_; static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_; diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/welford/block/block_welford.hpp index ce73c183e1..968895e38e 100644 --- a/include/ck_tile/ops/welford/block/block_welford.hpp +++ b/include/ck_tile/ops/welford/block/block_welford.hpp @@ -11,9 +11,10 @@ namespace ck_tile { template struct BlockWelford { - using Problem = remove_cvref_t; - using XDataType = typename Problem::XDataType; - using ComputeDataType = typename Problem::ComputeDataType; + using Problem = remove_cvref_t; + using XDataType = typename Problem::XDataType; + using ComputeDataType = typename Problem::ComputeDataType; + static constexpr bool kFastFDiv = Problem::kFastFDiv; CK_TILE_DEVICE constexpr BlockWelford() {} @@ -89,7 +90,8 @@ struct BlockWelford template struct BlockWelfordSync { - using Problem = remove_cvref_t; + using Problem = remove_cvref_t; + static constexpr bool kFastFDiv = Problem::kFastFDiv; template CK_TILE_DEVICE void @@ -173,8 +175,9 @@ struct BlockWelfordSync template struct BlockWelfordCrossWarpSync { - using Problem = remove_cvref_t; - using BlockShape = typename Problem::BlockShape; + using Problem = remove_cvref_t; + using BlockShape = typename Problem::BlockShape; + static constexpr bool kFastFDiv = Problem::kFastFDiv; template CK_TILE_DEVICE static constexpr index_t GetReduceWarps() @@ -351,12 +354,23 @@ CK_TILE_DEVICE constexpr index_t block_tile_welford_calculate_max_count(int row_ } // Note: this function must be called after all the computation -template +template CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor, - int count) + int count, + bool_constant = {}) { using DataType = typename VarDistributedTensor_::DataType; - tile_elementwise_inout([&count](auto& x) { x = x / type_convert(count); }, - var_tensor); + tile_elementwise_inout( + [&count](auto& x) { + if(FastFdiv_ && std::is_same_v) + { + x = x * __builtin_amdgcn_rcpf(type_convert(count)); + } + else + { + x = x / type_convert(count); + } + }, + var_tensor); } } // namespace ck_tile diff --git a/include/ck_tile/ops/welford/block/block_welford_problem.hpp b/include/ck_tile/ops/welford/block/block_welford_problem.hpp index dcae1ef2ee..bcbfb7d76e 100644 --- a/include/ck_tile/ops/welford/block/block_welford_problem.hpp +++ b/include/ck_tile/ops/welford/block/block_welford_problem.hpp @@ -7,12 +7,13 @@ namespace ck_tile { -template +template struct BlockWelfordProblem { - using XDataType = remove_cvref_t; - using ComputeDataType = remove_cvref_t; - using BlockShape = remove_cvref_t; + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + static constexpr bool kFastFDiv = kFastFDiv_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/welford/thread/thread_welford.hpp b/include/ck_tile/ops/welford/thread/thread_welford.hpp index 4c61cdcf4b..52b253e5f7 100644 --- a/include/ck_tile/ops/welford/thread/thread_welford.hpp +++ b/include/ck_tile/ops/welford/thread/thread_welford.hpp @@ -7,25 +7,46 @@ namespace ck_tile { -template -CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count) +template +CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count, bool_constant = {}) { // TODO: check nan? maybe no T delta = x - mean; - mean += delta / count; + if(kFastFDiv && std::is_same_v) + { + mean += delta * __builtin_amdgcn_rcpf(count); + } + else + { + mean += delta / count; + } T delta2 = x - mean; var += delta * delta2; } -template -CK_TILE_DEVICE static void -welford_merge(T& mean_a, T& var_a, int& count_a, T mean_b, T var_b, int count_b) +template +CK_TILE_DEVICE static void welford_merge(T& mean_a, + T& var_a, + int& count_a, + T mean_b, + T var_b, + int count_b, + bool_constant = {}) { - int count = count_a + count_b; - T count_ = type_convert(count); - T count_a_ = type_convert(count_a); - T count_b_ = type_convert(count_b); - T count_b_over_count = count == 0 ? type_convert(0) : count_b_ / count_; + int count = count_a + count_b; + T count_ = type_convert(count); + T count_a_ = type_convert(count_a); + T count_b_ = type_convert(count_b); + T count_b_over_count; + if(kFastFDiv && std::is_same_v) + { + count_b_over_count = + count == 0 ? type_convert(0) : count_b_ * __builtin_amdgcn_rcpf(count_); + } + else + { + count_b_over_count = count == 0 ? type_convert(0) : count_b_ / count_; + } T delta = mean_b - mean_a; mean_a += delta * count_b_over_count; From ea3640fdea4b11178c1657feff4849ad011e5d26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Fri, 8 Nov 2024 10:04:33 +0100 Subject: [PATCH 09/24] Add generic instances for two stage conv bwd wei (#1643) * Add generic instances for two stage conv bwd wei * Update layout prefix --- ...conv_bwd_weight_two_stage_xdl_instance.hpp | 76 ++++++++++++- .../grouped_convolution_backward_weight.hpp | 16 +++ ...rouped_convolution_backward_weight_xdl.inc | 100 ++++++++++++++++++ .../grouped_conv2d_bwd_weight/CMakeLists.txt | 4 + ...ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp | 41 +++++++ ..._ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp | 41 +++++++ ...nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp | 41 +++++++ ...nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp | 2 +- ...nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp | 2 +- ..._nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp | 41 +++++++ ..._nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp | 2 +- ..._nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp | 2 +- .../grouped_conv3d_bwd_weight/CMakeLists.txt | 4 + ...wgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp | 41 +++++++ ...wgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp | 2 +- ...wgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp | 2 +- ...hwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp | 41 +++++++ ...hwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp | 2 +- ...hwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp | 2 +- ...dhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp | 41 +++++++ ...cdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp | 41 +++++++ 21 files changed, 534 insertions(+), 10 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp index 5f6c340e48..d82f82cce2 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp @@ -39,7 +39,25 @@ template -using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances = std::tuple< +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances = + std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances = std::tuple< // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| @@ -64,7 +82,25 @@ template -using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances = std::tuple< +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances = + std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1> + // clang-format on + >; + +template +using device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances = std::tuple< // clang-format off //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| @@ -82,6 +118,24 @@ using device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances = st // clang-format on >; +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances = + std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, F16, F16, 1, 1> + // clang-format on + >; + // NGCHW requires transpose, we use vector loads and stores params for them template ; +template +using device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances = + std::tuple< + // clang-format off + //#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm| NumGroups| + //#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge| + //#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| MBlock_MPerBlock| NWaveNPerXdl| Scheduler| Version| | + //#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | | + DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, false, 1, 1, S<1, 8, 1, 8>, 1, Scheduler, PipelineVersion, 1, BF16, BF16, 1, 1> + // clang-format on + >; + template && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances( + op_ptrs); add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances( op_ptrs); add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev5_instances( @@ -403,6 +409,8 @@ struct DeviceOperationInstanceFactory && is_same_v) { + add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances( + op_ptrs); add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances( op_ptrs); add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev5_instances( @@ -464,6 +472,8 @@ struct DeviceOperationInstanceFactory && is_same_v && is_same_v) { + add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances( + op_ptrs); add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances( op_ptrs); add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev5_instances( @@ -524,6 +538,8 @@ struct DeviceOperationInstanceFactory && is_same_v) { + add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances( + op_ptrs); add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances( op_ptrs); add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instances( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc index 132dde81ae..630eb81357 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_weight_xdl.inc @@ -113,6 +113,18 @@ void add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_f32_bf16_in PassThrough, PassThrough>>>& instances); +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_bf16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instances( std::vector>>& instances); +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instances( std::vector>>& instances); + +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances( + std::vector>>& instances); + void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev2_instances( std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances< + 2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp new file mode 100644 index 0000000000..d70c95bf6e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_ngchw_gkyxc_ngkhw_f16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances< + 2, + NGCHW, + GKYXC, + NGKHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..74ccc4c89b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp index 0e4d085de8..fab2898559 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev2_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances< 2, NHWGC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp index 680494cfdf..407645e893 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_p // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances< 2, NHWGC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp new file mode 100644 index 0000000000..807de66ca5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances< + 2, + NHWGC, + GKYXC, + NHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp index 15401f0e1b..084c83cd65 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev2_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances< 2, NHWGC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp index 398c14b11c..d174e5b6c0 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_weight/xdl/device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pi // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances< 2, NHWGC, GKYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt index c8c30897cf..cf4e323bfe 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/CMakeLists.txt @@ -15,6 +15,10 @@ set(GROUPED_CONV3D_BWD_WEIGHT xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev2_instance.cpp xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev5_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp + xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp ) if(DL_KERNELS) diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..63249a1c13 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_generic_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp index 549716586d..7841ddad99 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev2_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances< 3, NDHWGC, GKZYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp index 18a00c6ea7..ba6285a380 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf1 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_bf16_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_bf16_instances< 3, NDHWGC, GKZYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp new file mode 100644 index 0000000000..a8fbefb5bd --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_generic_instances< + 3, + NDHWGC, + GKZYXC, + NDHWGK, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp index 4d0f1e68cb..e4baafc0be 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev2_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances< 3, NDHWGC, GKZYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp index c5cc062f2a..f9bc5b1349 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_instance.cpp @@ -25,7 +25,7 @@ void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16 // 1. Default add_device_operation_instances( instances, - device_grouped_conv_bwd_weight_two_stage_xdl_c_shuffle_f16_instances< + device_grouped_conv_bwd_weight_two_stage_nhwgc_xdl_c_shuffle_f16_instances< 3, NDHWGC, GKZYXC, diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp new file mode 100644 index 0000000000..16221eb3e7 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_bf16_generic_instances< + 3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp new file mode 100644 index 0000000000..126e90f2ce --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv3d_bwd_weight/xdl/device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instance.cpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_two_stage_xdl_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +// Compilation parameters for in[n, hi, wi, g, c] * wei[g, k, y, x, c] = out[n, ho, wo, g, k] +void add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ngcdhw_gkzyxc_ngkdhw_f16_pipev1_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_weight_two_stage_ngchw_xdl_c_shuffle_f16_generic_instances< + 3, + NGCDHW, + GKZYXC, + NGKDHW, + ConvBwdWeightDefault, + BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion::v1>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From af9546d9f4dba6945e23e1c346f92678f0f208f9 Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Sat, 9 Nov 2024 09:55:14 +0800 Subject: [PATCH 10/24] Fix 'sh' command compatibility of smoke_test_fwd.sh (#1553) --- .../ck_tile/01_fmha/script/smoke_test_fwd.sh | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh index 5dcc6ed42b..b867cd6c07 100755 --- a/example/ck_tile/01_fmha/script/smoke_test_fwd.sh +++ b/example/ck_tile/01_fmha/script/smoke_test_fwd.sh @@ -29,14 +29,14 @@ while getopts ":sa" opt; do done run_fp16_bf16_tests() { - local NUM_SPLITS=(1) - local PAGE_BLOCK_SIZE=(0) - local CACHE_BATCH_IDX=(0) + local NUM_SPLITS="1" + local PAGE_BLOCK_SIZE="0" + local CACHE_BATCH_IDX="0" if [ $TEST_SPLITKV -eq 1 ] ; then - NUM_SPLITS+=(2 3) - PAGE_BLOCK_SIZE+=(128) - CACHE_BATCH_IDX+=(1) + NUM_SPLITS="$NUM_SPLITS 2 3" + PAGE_BLOCK_SIZE="$PAGE_BLOCK_SIZE 128" + CACHE_BATCH_IDX="$CACHE_BATCH_IDX 1" fi for prec in "fp16" "bf16" ; do @@ -47,9 +47,9 @@ run_fp16_bf16_tests() { for lse in 0 1 ; do for bias in "n" "e" "a" ; do for p_drop in 0.0 0.2 ; do - for num_splits in "${NUM_SPLITS[@]}" ; do - for page_block_size in "${PAGE_BLOCK_SIZE[@]}" ; do - for cache_batch_idx in "${CACHE_BATCH_IDX[@]}" ; do + for num_splits in $NUM_SPLITS ; do + for page_block_size in $PAGE_BLOCK_SIZE ; do + for cache_batch_idx in $CACHE_BATCH_IDX ; do # $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS $EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS @@ -103,4 +103,4 @@ if [ $TEST_APPENDKV -eq 1 ] ; then run_fp16_appendkv_tests fi -set +x \ No newline at end of file +set +x From bec6fbc65fe766ab23fe563675703defdb0dd2be Mon Sep 17 00:00:00 2001 From: dummycoderfe Date: Sat, 9 Nov 2024 17:57:27 +0800 Subject: [PATCH 11/24] Ck tile/moe sorting (#1624) * add moe_sorting & check ok * fix comments & typo * Run remod.py under include/ck_tile & example/ck_tile directories * format codes * fix output ci check bug * fix moe sorting readme and error commit file * use magiv div to accelerate compute * add an loop unroll for moe lds ops * add extblocksnel to set zeros for moebufs * [Ck_tile] moe set zero run ok, add size check and fix ref check * [Ck_tile]fix moe_sorting fuse set_zero remod * [Ck_tile] change name style, fix zero buffer size err, change folder * [Ck_tile] moe_sorting: fix name style * [Ck_tile] moe_sorting, remove useless params in traits * [Ck_tile] change outputtile cnt * unit_size; change output buf alloc --------- Co-authored-by: dummycoderfe Co-authored-by: Po Yen, Chen Co-authored-by: carlushuang --- example/ck_tile/13_moe_sorting/CMakeLists.txt | 8 + example/ck_tile/13_moe_sorting/README.md | 27 ++ .../ck_tile/13_moe_sorting/moe_sorting.cpp | 223 +++++++++++++++++ .../13_moe_sorting/moe_sorting_api.cpp | 73 ++++++ .../13_moe_sorting/moe_sorting_api.hpp | 20 ++ .../13_moe_sorting/script/smoke_test.sh | 19 ++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/host.hpp | 1 + .../host/reference/reference_moe_sorting.hpp | 78 ++++++ .../fused_moe/kernel/moe_sorting_kernel.hpp | 232 ++++++++++++++++++ .../pipeline/moe_sorting_pipeline.hpp | 39 +++ .../fused_moe/pipeline/moe_sorting_policy.hpp | 15 ++ .../pipeline/moe_sorting_problem.hpp | 23 ++ include/ck_tile/ops/moe_sorting.hpp | 11 + 14 files changed, 770 insertions(+) create mode 100644 example/ck_tile/13_moe_sorting/CMakeLists.txt create mode 100644 example/ck_tile/13_moe_sorting/README.md create mode 100644 example/ck_tile/13_moe_sorting/moe_sorting.cpp create mode 100644 example/ck_tile/13_moe_sorting/moe_sorting_api.cpp create mode 100644 example/ck_tile/13_moe_sorting/moe_sorting_api.hpp create mode 100644 example/ck_tile/13_moe_sorting/script/smoke_test.sh create mode 100644 include/ck_tile/host/reference/reference_moe_sorting.hpp create mode 100644 include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp create mode 100644 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp create mode 100644 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp create mode 100644 include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp create mode 100644 include/ck_tile/ops/moe_sorting.hpp diff --git a/example/ck_tile/13_moe_sorting/CMakeLists.txt b/example/ck_tile/13_moe_sorting/CMakeLists.txt new file mode 100644 index 0000000000..09f3e4ac4e --- /dev/null +++ b/example/ck_tile/13_moe_sorting/CMakeLists.txt @@ -0,0 +1,8 @@ +add_executable(tile_example_moe_sorting EXCLUDE_FROM_ALL moe_sorting.cpp moe_sorting_api.cpp) +target_include_directories(tile_example_moe_sorting PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/) + +set(EXAMPLE_MOE_SORTING_COMPILE_OPTIONS) +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) +# list(APPEND EXAMPLE_MOE_SORTING_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) +target_compile_options(tile_example_moe_sorting PRIVATE ${EXAMPLE_MOE_SORTING_COMPILE_OPTIONS}) diff --git a/example/ck_tile/13_moe_sorting/README.md b/example/ck_tile/13_moe_sorting/README.md new file mode 100644 index 0000000000..7b6792dd95 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/README.md @@ -0,0 +1,27 @@ +# moe-sorting + +This folder contains example for moe-sorting kernel using ck_tile tile-programming implementation. This kernel is often used in Moe model, before launching the fused-moe-gemm block. The input&weight is a `token*topk` 2d matrix. The op rearange the input weight ids into different experts and feed into fuse moe gemm kernel. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_moe_sorting -j +``` +This will result in an executable `build/bin/tile_example_moe_sorting` + +## example +``` +args: + -v weather do CPU validation or not (default:1) + -pr_i index data type. (currently only fp32 supported now) (default:int32) + -pr_w output weight data type(currently only fp32 supported now) (default:fp32) + -t number of input tokens (default:32) + -e number of experts (default:8) + -k topk (default:2) + -st_i row stride of input, -1 means same as experts (default:-1) + -seed seed to be used, -1 means random every time (default:-1) + -kname when set to 1 it will print kernel name (default:0) + +``` diff --git a/example/ck_tile/13_moe_sorting/moe_sorting.cpp b/example/ck_tile/13_moe_sorting/moe_sorting.cpp new file mode 100644 index 0000000000..d2c4df1058 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting.cpp @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "moe_sorting_api.hpp" + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "weather do CPU validation or not") + .insert("pr_i", "int32", "index data type. (currently only int32 supported now)") + .insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)") + .insert("t", "128", "number of input tokens") + .insert("e", "8", "number of num_experts") + .insert("k", "4", "topk") + .insert("unit", "32", "unit_size") + .insert("moe_buf_size", "0", "moe_buf_size") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "when set to 1 it will print kernel name") + .insert("warmup", "5", "number of iterations before benchmark the kernel") + .insert("repeat", "20", "number of iterations to benchmark the kernel"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +void topid_unique_gen( + std::vector& host_tensor, int tokens, int topk, int num_expert, int seed) +{ + size_t total_size = topk * tokens; + std::srand(seed); + std::set unique_set; + IndexType current_v; + for(size_t i = 0; i < total_size; i++) + { + if(i % topk == 0) + { + unique_set.clear(); + } + current_v = std::rand() % num_expert; + while(unique_set.find(current_v) != unique_set.end()) + { + current_v = std::rand() % num_expert; + } + unique_set.insert(current_v); + host_tensor[i] = current_v; + } +} + +template +bool test_moe_sorting(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string index_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + int tokens = args.get_int("t"); + int num_experts = args.get_int("e"); + int topk = args.get_int("k"); + int seed = args.get_int("seed"); + int unit_size = args.get_int("unit"); + int moe_buf_size = args.get_int("moe_buf_size"); + int kname = args.get_int("kname"); + int warmup = args.get_int("warmup"); + int repeat = args.get_int("repeat"); + int max_output_ids = + ck_tile::integer_least_multiple(topk * tokens + num_experts * unit_size - topk, unit_size); + + if(seed < 0) + { + seed = std::time(nullptr); + } + + if(topk > num_experts) + { + printf("topk:%d value should be smaller than, or equal to number of num_experts:%d\n", + topk, + num_experts); + return false; + } + + // tokens already considered batch size + ck_tile::HostTensor topk_ids_host({tokens, topk}, {topk, 1}); + ck_tile::HostTensor weights_host({tokens, topk}, {topk, 1}); + ck_tile::HostTensor sorted_ids_host({max_output_ids}, {1}); + ck_tile::HostTensor sorted_weights_host({max_output_ids}, {1}); + ck_tile::HostTensor sorted_expert_ids_host({max_output_ids / unit_size}, {1}); + ck_tile::HostTensor sorted_id_cnt_host({1}, {1}); + ck_tile::HostTensor moe_buf_host({moe_buf_size}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(weights_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(moe_buf_host); + topid_unique_gen(topk_ids_host.mData, tokens, topk, num_experts, seed); + + ck_tile::DeviceMem topk_ids_dev(topk_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem weights_dev(weights_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_ids_dev(sorted_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_weights_dev(sorted_weights_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_expert_ids_dev( + sorted_expert_ids_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem sorted_id_cnt_dev(sorted_id_cnt_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem moe_buf_dev(moe_buf_host.get_element_space_size_in_bytes()); + + topk_ids_dev.ToDevice(topk_ids_host.data()); + weights_dev.ToDevice(weights_host.data()); + if(moe_buf_size > 0) + { + moe_buf_dev.ToDevice(moe_buf_host.data()); + } + + moe_sorting_trait trait{index_prec, weight_prec}; + + moe_sorting_args karg{topk_ids_dev.GetDeviceBuffer(), + weights_dev.GetDeviceBuffer(), + sorted_ids_dev.GetDeviceBuffer(), + sorted_weights_dev.GetDeviceBuffer(), + sorted_expert_ids_dev.GetDeviceBuffer(), + sorted_id_cnt_dev.GetDeviceBuffer(), + moe_buf_size > 0 ? moe_buf_dev.GetDeviceBuffer() : nullptr, + tokens, + unit_size, + num_experts, + topk, + static_cast(moe_buf_size * sizeof(float))}; + + ck_tile::stream_config sc{nullptr, + true, + /* log_level = */ (kname ? 1 : 0), + warmup, + repeat}; + auto ms = moe_sorting(trait, karg, sc); + printf("[%s|%s]tokens:%d, num_experts:%d, topk:%d, ms:%f , ", + index_prec.c_str(), + weight_prec.c_str(), + tokens, + num_experts, + topk, + ms); + if(ms < 0) + printf("not supported\n"); + fflush(stdout); + if(ms < 0) + { + return false; + } + + sorted_ids_dev.FromDevice(sorted_ids_host.data()); + sorted_weights_dev.FromDevice(sorted_weights_host.data()); + sorted_expert_ids_dev.FromDevice(sorted_expert_ids_host.data()); + sorted_id_cnt_dev.FromDevice(sorted_id_cnt_host.data()); + if(moe_buf_size > 0) + { + moe_buf_dev.FromDevice(moe_buf_host.data()); + } + + bool rtn = true; + if(validate) + { + ck_tile::HostTensor sorted_ids_ref({max_output_ids}, {1}); + ck_tile::HostTensor sorted_weights_ref({max_output_ids}, {1}); + ck_tile::HostTensor sorted_expert_ids_ref({max_output_ids / unit_size}, {1}); + + int32_t ref_total_tokens_post_pad = 0; + ck_tile::reference_moe_sorting(topk_ids_host, + weights_host, + sorted_ids_ref, + sorted_weights_ref, + sorted_expert_ids_ref, + ref_total_tokens_post_pad, + num_experts, + unit_size); + rtn &= ck_tile::check_err( + sorted_ids_host, sorted_ids_ref, std::string("OUT Error: Incorrect ids!"), 1e-6, 1e-6); + rtn &= ck_tile::check_err(sorted_weights_host, + sorted_weights_ref, + std::string("OUT Error: Incorrect w!"), + 1e-6, + 1e-6); + rtn &= ck_tile::check_err(sorted_expert_ids_host, + sorted_expert_ids_ref, + std::string("OUT Error: Incorrect eid!"), + 1e-6, + 1e-6); + if(moe_buf_size) + { + ck_tile::HostTensor moe_buf_ref({moe_buf_size}); + rtn &= ck_tile::check_err( + moe_buf_host, moe_buf_ref, std::string("OUT Error: Incorrect zero buf!"), 0, 0); + } + rtn &= ref_total_tokens_post_pad == sorted_id_cnt_host.mData[0]; + } + + printf("valid:%s\n", rtn ? "y" : "n"); + fflush(stdout); + return rtn; +} + +int main(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return -1; + std::string index_prec = args.get_str("pr_i"); + std::string weight_prec = args.get_str("pr_w"); + + bool r = true; + if(weight_prec.compare("fp32") == 0 && index_prec.compare("int32") == 0) + { + r &= test_moe_sorting(args); + } + return r ? 0 : -1; +} diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp new file mode 100644 index 0000000000..25e99c5306 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.cpp @@ -0,0 +1,73 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "moe_sorting_api.hpp" + +#define MOE_SORTING_DISPATCH(unroll_num_) \ + constexpr ck_tile::index_t unroll_num = unroll_num_; \ + using ms_problem = ck_tile::MoeSortingProblem; \ + using kernel = ck_tile::MoeSortingKernel; \ + auto kargs = kernel::MakeKargs(a); \ + const dim3 grids = kernel::GridSize(a); \ + const dim3 blocks = kernel::BlockSize(a); \ + const auto lds_bytes = kernel::GetSmemSize(a); \ + float ave_time = ck_tile::launch_kernel( \ + s, ck_tile::make_kernel(kernel{}, grids, blocks, lds_bytes, kargs)); \ + return ave_time; + +float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s) +{ + if(t.weight_type == "fp32" && t.index_type == "int32") + { + if(a.num_experts > 127) + { + printf("lds size exceed, only support experts <127 \n"); + return -1; + } + if(a.moe_buf_bytes % 16) + { + printf("buf set size %d unaligned, must be multiple of 16\n", a.moe_buf_bytes); + return -1; + } + using index_t = ck_tile::index_t; + using ms_weight_type = float; + index_t smem_io_unroll_num = ck_tile::integer_divide_ceil(a.tokens * a.topk, 64); + switch(smem_io_unroll_num) + { + case(1): { + MOE_SORTING_DISPATCH(1); + } + case(2): { + MOE_SORTING_DISPATCH(2); + } + case(3): { + MOE_SORTING_DISPATCH(3); + } + case(5): { + MOE_SORTING_DISPATCH(5); + } + case(6): { + MOE_SORTING_DISPATCH(6); + } + case(7): { + MOE_SORTING_DISPATCH(7); + } + case(8): { + MOE_SORTING_DISPATCH(8); + } + case(9): { + MOE_SORTING_DISPATCH(9); + } + case(10): { + MOE_SORTING_DISPATCH(10); + } + case(11): { + MOE_SORTING_DISPATCH(11); + } + default: { + MOE_SORTING_DISPATCH(4); + } + } + } + return -1; +} diff --git a/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp new file mode 100644 index 0000000000..91b54932ce --- /dev/null +++ b/example/ck_tile/13_moe_sorting/moe_sorting_api.hpp @@ -0,0 +1,20 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once +#include +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/moe_sorting.hpp" + +struct moe_sorting_trait +{ + std::string index_type; + std::string weight_type; // currently always float +}; + +struct moe_sorting_args : public ck_tile::MoeSortingHostArgs +{ +}; + +float moe_sorting(moe_sorting_trait t, moe_sorting_args a, ck_tile::stream_config s); diff --git a/example/ck_tile/13_moe_sorting/script/smoke_test.sh b/example/ck_tile/13_moe_sorting/script/smoke_test.sh new file mode 100644 index 0000000000..1fc5eafcb0 --- /dev/null +++ b/example/ck_tile/13_moe_sorting/script/smoke_test.sh @@ -0,0 +1,19 @@ +# #!/bin/sh + +EXE=./build/bin/tile_example_moe_sorting + +$EXE -t=80 -e=17 -moe_buf_size=16 +$EXE -t=111 -e=117 -moe_buf_size=4 +$EXE -t=1000 -e=55 -moe_buf_size=1024 +$EXE -t=99 -e=120 -moe_buf_size=10244 +$EXE -t=175 -e=64 -k=8 +$EXE -t=65 -e=8 -k=2 +$EXE -t=1 -e=25 +$EXE -t=31 -e=19 -k=15 +$EXE -t=81 -e=37 -k=7 +$EXE -t=23 -e=1 -k=1 +$EXE -t=127 -e=99 -k=19 +$EXE -t=71 -e=11 -k=11 +$EXE -t=1 -e=1 -k=1 +$EXE -t=99 -e=2 -k=1 +$EXE -t=333 -e=99 -k=13 \ No newline at end of file diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index 9dd9a6ca3c..15db0f46c4 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -12,3 +12,4 @@ add_subdirectory(09_topk_softmax) add_subdirectory(10_rmsnorm2d) add_subdirectory(11_add_rmsnorm2d_rdquant) add_subdirectory(12_smoothquant) +add_subdirectory(13_moe_sorting) diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index c0ab13ce3d..2e96009ace 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -23,6 +23,7 @@ #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" +#include "ck_tile/host/reference/reference_moe_sorting.hpp" #include "ck_tile/host/reference/reference_permute.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" #include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" diff --git a/include/ck_tile/host/reference/reference_moe_sorting.hpp b/include/ck_tile/host/reference/reference_moe_sorting.hpp new file mode 100644 index 0000000000..c8eb7edb55 --- /dev/null +++ b/include/ck_tile/host/reference/reference_moe_sorting.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +CK_TILE_HOST void reference_moe_sorting(const HostTensor& topk_ids, + const HostTensor& weights, + HostTensor& p_sorted_token_ids, + HostTensor& sorted_weight, + HostTensor& sorted_expert_ids, + index_t& unit_cnt, + const index_t experts, + const index_t unit_size) +{ + const index_t num_token = topk_ids.mDesc.get_lengths()[0]; + const index_t topk = topk_ids.mDesc.get_lengths()[1]; + std::vector> expert_tokens(experts, + std::vector(unit_size, num_token)); + std::vector> expert_token_weights( + experts, std::vector(unit_size, 0)); + std::vector expert_slices(experts, 1); + std::vector expert_slice_idxs(experts, 0); + + for(index_t t = 0; t < num_token; t++) + { + for(index_t k = 0; k < topk; k++) + { + IndexType e = topk_ids(t, k); + WeightType w = weights(t, k); + index_t idx = expert_slice_idxs[e]; + if(idx > expert_slices[e] * unit_size - 1) + { + expert_slices[e]++; + index_t new_size = expert_slices[e] * unit_size; + expert_tokens[e].resize(new_size); + expert_token_weights[e].resize(new_size); + for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++) + { + expert_tokens[e][i] = num_token; + expert_token_weights[e][i] = 0; + } + } + + expert_tokens[e][idx] = t; + expert_token_weights[e][idx] = w; + expert_slice_idxs[e]++; + } + } + + IndexType* out_tokens = p_sorted_token_ids.data(); + WeightType* out_weights = sorted_weight.data(); + IndexType* out_expert_id = sorted_expert_ids.data(); + for(index_t e = 0; e < experts; e++) + { + memcpy(out_tokens, expert_tokens[e].data(), sizeof(index_t) * expert_slices[e] * unit_size); + out_tokens += expert_slices[e] * unit_size; + memcpy(out_weights, + expert_token_weights[e].data(), + sizeof(WeightType) * expert_slices[e] * unit_size); + out_weights += expert_slices[e] * unit_size; + + for(index_t s = 0; s < expert_slices[e]; s++) + { + out_expert_id[s] = e; + unit_cnt++; + } + out_expert_id += expert_slices[e]; + } + unit_cnt *= unit_size; + return; +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp new file mode 100644 index 0000000000..1c6acec70e --- /dev/null +++ b/include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp @@ -0,0 +1,232 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/elementwise.hpp" +#include "ck_tile/host/hip_check_error.hpp" +#include +#include + +namespace ck_tile { + +struct MoeSortingHostArgs +{ + const void* p_topk_ids; + const void* p_weights; + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_sorted_expert_ids; + void* p_total_tokens_post_pad; + void* p_moe_buf; + index_t tokens; + index_t unit_size; + index_t num_experts; + index_t topk; + index_t moe_buf_bytes; +}; + +template +struct MoeSortingKernel +{ + using Problem = remove_cvref_t; + + using IndexType = typename Problem::IndexType; + using WeightType = typename Problem::WeightType; + + typedef MoeSortingHostArgs MoeSortingKargs; + + using Hargs = MoeSortingHostArgs; + + struct Kargs + { + const void* p_topk_ids; + const void* p_weights; + void* p_sorted_token_ids; + void* p_sorted_weights; + void* p_sorted_expert_ids; + void* p_total_tokens_post_pad; + void* p_moe_buf; + index_t tokens; + index_t num_experts; + index_t moe_buf_bytes; + + index_t tokens_per_thread; + mdiv unit_size_mdiv; + mdiv topk_mdiv; + }; + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) + { + // TODO: assume num-experts not too much + return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16)); + } + + CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h) + { + return dim3(ck_tile::integer_least_multiple(h.num_experts, ck_tile::get_warp_size())); + } + + // in byte + CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h) + { + const auto blocks = BlockSize(h); + return ((blocks.x + 1) * h.num_experts + (h.num_experts + 1)) * sizeof(index_t); + } + + CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h) + { + Kargs k; + k.p_topk_ids = h.p_topk_ids; + k.p_weights = h.p_weights; + k.p_sorted_token_ids = h.p_sorted_token_ids; + k.p_sorted_weights = h.p_sorted_weights; + k.p_sorted_expert_ids = h.p_sorted_expert_ids; + k.p_moe_buf = h.p_moe_buf; + k.p_total_tokens_post_pad = h.p_total_tokens_post_pad; + k.tokens = h.tokens; + k.num_experts = h.num_experts; + k.moe_buf_bytes = h.moe_buf_bytes; + + const auto blocks = BlockSize(h); + k.tokens_per_thread = integer_divide_ceil(h.tokens * h.topk, blocks.x); + k.unit_size_mdiv = mdiv{static_cast(h.unit_size)}; + k.topk_mdiv = mdiv{static_cast(h.topk)}; + return k; + } + + CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const + { + return row * total_col + col; + } + + CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t* buf, index_t buf_bytes) const + { + const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x; + if(offset < buf_bytes / 16) + { + buf[offset] = uint8x16_t{0}; + } + } + + CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id, + const WeightType* __restrict__ weights, + index_t* p_sorted_token_ids, + WeightType* p_sorted_weights, + index_t* p_sorted_expert_ids, + index_t* p_total_tokens_post_pad, + const index_t num_experts, + const index_t tokens_per_thread, + const index_t numel, + const mdiv unit_size_mdiv, + const mdiv topk_mdiv, + void* smem) const + { + const index_t tid = static_cast(threadIdx.x); + const index_t start_idx = tid * tokens_per_thread; + + index_t* shared_mem = reinterpret_cast(smem); + + index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts) + index_t* cumsum = shared_mem + (blockDim.x + 1) * num_experts; // 1: (num_experts + 1) + for(int i = 0; i < num_experts; ++i) + { + tokens_cnts[calc_index(num_experts, tid + 1, i)] = 0; + } +#pragma unroll Problem_::InternalLoadUnroll + for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) + { + ++tokens_cnts[calc_index(num_experts, tid + 1, topk_id[i])]; + } + __syncthreads(); + + if(tid < num_experts) + { + tokens_cnts[calc_index(num_experts, 0, tid)] = 0; + for(int i = 1; i <= static_cast(blockDim.x); ++i) + { + tokens_cnts[calc_index(num_experts, i, tid)] += + tokens_cnts[calc_index(num_experts, i - 1, tid)]; + } + } + + // __syncthreads(); + if(tid == 0) + { + cumsum[0] = 0; + for(int i = 1; i <= num_experts; ++i) + { + auto current_units = [&]() { + index_t x_ = tokens_cnts[calc_index(num_experts, blockDim.x, i - 1)] + + unit_size_mdiv.divisor - 1; + index_t y_ = unit_size_mdiv.div(x_); + return max(y_, 1) * unit_size_mdiv.divisor; + }(); + cumsum[i] = cumsum[i - 1] + current_units; + } + *p_total_tokens_post_pad = cumsum[num_experts]; + } + __syncthreads(); + if(tid < num_experts) + { + for(int i = cumsum[tid]; i < cumsum[tid + 1]; i += unit_size_mdiv.divisor) + { + p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid; + } + } + +#pragma unroll Problem_::InternalLoadUnroll + for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) + { + index_t expert_id = topk_id[i]; + index_t rank_post_pad = + tokens_cnts[calc_index(num_experts, tid, expert_id)] + cumsum[expert_id]; + p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i); + p_sorted_weights[rank_post_pad] = weights[i]; + ++tokens_cnts[calc_index(num_experts, tid, expert_id)]; + } + + const index_t prefill_token = topk_mdiv.div(numel); + if(tid < num_experts) + { + index_t expert_offset = + cumsum[tid] + tokens_cnts[calc_index(num_experts, blockDim.x, tid)]; + while(expert_offset < cumsum[tid + 1]) + { + p_sorted_token_ids[expert_offset] = prefill_token; + p_sorted_weights[expert_offset] = static_cast(0.0); + expert_offset++; + } + } + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + if(blockIdx.x > 0) + { + if(kargs.p_moe_buf) + { + moe_buf_set_zero_kernel(reinterpret_cast(kargs.p_moe_buf), + kargs.moe_buf_bytes); + } + return; + } + const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor; + extern __shared__ char smem[]; + return moe_align_block_size_kernel(static_cast(kargs.p_topk_ids), + static_cast(kargs.p_weights), + static_cast(kargs.p_sorted_token_ids), + static_cast(kargs.p_sorted_weights), + static_cast(kargs.p_sorted_expert_ids), + static_cast(kargs.p_total_tokens_post_pad), + kargs.num_experts, + kargs.tokens_per_thread, + numel, + kargs.unit_size_mdiv, + kargs.topk_mdiv, + smem); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp new file mode 100644 index 0000000000..bbd47352d4 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp @@ -0,0 +1,39 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include +#include + +#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW +#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0 +#endif + +namespace ck_tile { + +// template +// struct MoeSortingPipeline +// { +// // TODO: this kernel only support warp per row +// using Problem = remove_cvref_t; +// using Policy = remove_cvref_t; +// using WeightType = typename Problem::WeightType; + +// template +// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window, +// const WeightWindow& weight_window, +// index_t* p_sorted_token_ids, +// WeightType* p_sorted_weights, +// index_t* p_sorted_expert_ids, +// index_t* p_total_tokens_post_pad, +// const index_t num_experts, +// const index_t unit_size, +// const size_t numel, +// const index_t topk) +// { +// } +// }; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp new file mode 100644 index 0000000000..f5218a93e2 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp @@ -0,0 +1,15 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/softmax.hpp" +#include "ck_tile/ops/topk.hpp" + +namespace ck_tile { + +struct MoeSortingPolicy +{ +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp new file mode 100644 index 0000000000..adde59e356 --- /dev/null +++ b/include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp @@ -0,0 +1,23 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include +#include + +namespace ck_tile { + +template +struct MoeSortingProblem +{ + // TODO: this kernel only support warp per row + using WeightType = remove_cvref_t; + using IndexType = remove_cvref_t; + + static constexpr index_t WarpSize = get_warp_size(); + static constexpr index_t WarpsPerBlock = 1; + static constexpr index_t InternalLoadUnroll = InternalLoadUnroll_; +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/moe_sorting.hpp b/include/ck_tile/ops/moe_sorting.hpp new file mode 100644 index 0000000000..b74607f061 --- /dev/null +++ b/include/ck_tile/ops/moe_sorting.hpp @@ -0,0 +1,11 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" From 13332998a4ca6dcc8cc5fcd401ca900529e5e65c Mon Sep 17 00:00:00 2001 From: Po Yen Chen Date: Mon, 11 Nov 2024 09:28:32 +0800 Subject: [PATCH 12/24] Return nullptr when block index is invalid (#1649) --- .../ck_tile/ops/fmha/block/page_block_navigator.hpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp index e8abdc579b..5d158f9fb3 100644 --- a/include/ck_tile/ops/fmha/block/page_block_navigator.hpp +++ b/include/ck_tile/ops/fmha/block/page_block_navigator.hpp @@ -230,7 +230,15 @@ struct PageBlockNavigator CK_TILE_HOST_DEVICE DataType* get_block_ptr(index_t block_index) const { - return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset; + if(block_index < num_blocks) + { + return physical_blocks + physical_block_indices[block_index] * block_stride + + fixed_offset; + } + else + { + return nullptr; + } } CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const From 8ef8a994e73370d69980a4df7377ed4ce8ed05c8 Mon Sep 17 00:00:00 2001 From: valarLip <103567126+valarLip@users.noreply.github.com> Date: Mon, 11 Nov 2024 16:02:28 +0800 Subject: [PATCH 13/24] [CK_TILE] add more stride for layernorm to support un-continuous Tensor (#1650) * [CK_TILE] add more stride for layernorm to support un-continuous Tensor * align CK coding style * extend strides to layernrom expample * clang-format... --- .../02_layernorm2d/layernorm2d_fwd.cpp | 63 ++++++++++++------- .../kernel/layernorm2d_fwd_kernel.hpp | 23 ++++--- 2 files changed, 56 insertions(+), 30 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp index 8f029c212c..b49c04619d 100644 --- a/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp +++ b/example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp @@ -25,7 +25,10 @@ auto create_args(int argc, char* argv[]) ck_tile::ArgParser arg_parser; arg_parser.insert("m", "3328", "m dimension") .insert("n", "4096", "n dimension") - .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("x_stride", "-1", "x row_stride, if -1 then equal to n") + .insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n") + .insert("y_stride", "-1", "y row_stride, if -1 then equal to n") + .insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n") .insert("e", "1e-5", "epsilon") .insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case") .insert("v", "1", "cpu validation or not") @@ -54,11 +57,20 @@ template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - ck_tile::index_t stride = arg_parser.get_int("stride"); - if(stride < 0) - stride = n; + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride"); + if(xr_stride < 0) + xr_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; + ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride"); + if(yr_stride < 0) + yr_stride = n; float epsilon = arg_parser.get_float("e"); std::string prec_i = arg_parser.get_str("prec_i"); std::string prec_o = arg_parser.get_str("prec_o"); @@ -89,7 +101,7 @@ bool run(const ck_tile::ArgParser& arg_parser) return false; } - assert(stride >= n); + assert(x_stride >= n); using TypeConfig = LayerNormTypeConfig; @@ -108,15 +120,15 @@ bool run(const ck_tile::ArgParser& arg_parser) using ComputeDataType = typename TypeConfig::ComputeDataType; // host verify - ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); ck_tile::HostTensor gamma_host({n}); ck_tile::HostTensor beta_host({n}); - ck_tile::HostTensor x_residual_host({m, n}, {stride, 1}); - ck_tile::HostTensor y_residual_host({m, n}, {stride, 1}); + ck_tile::HostTensor x_residual_host({m, n}, {xr_stride, 1}); + ck_tile::HostTensor y_residual_host({m, n}, {yr_stride, 1}); - ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); - ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {y_stride, 1}); ck_tile::HostTensor mean_host_ref({m}); ck_tile::HostTensor invStd_host_ref({m}); @@ -162,7 +174,9 @@ bool run(const ck_tile::ArgParser& arg_parser) }(); std::cout << "[" << prec_str << "]" - << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride + << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride + << ", yr_stride:" << yr_stride << std::flush; layernorm2d_fwd_traits traits{ prec_i, prec_o, prec_sx, prec_sy, SaveMeanVar, fused_add, fused_quant}; @@ -182,7 +196,10 @@ bool run(const ck_tile::ArgParser& arg_parser) epsilon, m, n, - stride}; + x_stride, // x row_stride + xr_stride, // x residule row stride + y_stride, // y row stride + yr_stride}; // y residule row stride float ave_time = layernorm2d_fwd( traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); @@ -285,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser) y_buf.FromDevice(y_host_dev.data()); - ck_tile::HostTensor y_residual_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor y_residual_host_dev({m, n}, {yr_stride, 1}); if(fused_add == 1) { y_residual_buf.FromDevice(y_residual_host_dev.data()); @@ -293,7 +310,7 @@ bool run(const ck_tile::ArgParser& arg_parser) auto [rtol, atol] = get_elimit(); - if(stride == n) + if(x_stride == n) { pass = ck_tile::check_err( y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); @@ -310,10 +327,10 @@ bool run(const ck_tile::ArgParser& arg_parser) { for(int i_r = 0; i_r < m; i_r++) { - std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, - y_host_dev.begin() + i_r * stride + n); - std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, - y_host_ref.begin() + i_r * stride + n); + std::vector y_host_dev_row(y_host_dev.begin() + i_r * y_stride, + y_host_dev.begin() + i_r * y_stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * y_stride, + y_host_ref.begin() + i_r * y_stride + n); pass &= ck_tile::check_err(y_host_dev_row, y_host_ref_row, std::string("OUT[") + std::to_string(i_r) + @@ -323,10 +340,10 @@ bool run(const ck_tile::ArgParser& arg_parser) if(fused_add == 1) { std::vector y_residual_host_dev_row( - y_residual_host_dev.begin() + i_r * stride, - y_residual_host_dev.begin() + i_r * stride + n); + y_residual_host_dev.begin() + i_r * yr_stride, + y_residual_host_dev.begin() + i_r * yr_stride + n); std::vector y_residual_host_ref_row( - x_host.begin() + i_r * stride, x_host.begin() + i_r * stride + n); + x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n); pass &= ck_tile::check_err(y_residual_host_dev_row, y_residual_host_ref_row, std::string("ADD[") + std::to_string(i_r) + diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index f5a214ba57..10218e8084 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs index_t m; index_t n; - index_t stride; // row_stride + index_t x_stride; // x row_stride + index_t xr_stride; // x residule row stride + index_t y_stride; // y row stride + index_t yr_stride; // y residule row stride }; // TODO: Extract some type to wrapper class @@ -93,7 +96,10 @@ struct Layernorm2dFwd index_t m; index_t n; - index_t stride; // row_stride + index_t x_stride; // x row_stride + index_t xr_stride; // x residule row stride + index_t y_stride; // y row stride + index_t yr_stride; // y residule row stride }; using Hargs = Layernorm2dFwdHostArgs; @@ -112,7 +118,10 @@ struct Layernorm2dFwd hargs.epsilon, hargs.m, hargs.n, - hargs.stride}; + hargs.x_stride, + hargs.xr_stride, + hargs.y_stride, + hargs.yr_stride}; } CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) @@ -182,7 +191,7 @@ struct Layernorm2dFwd const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_x), make_tuple(kargs.m, kargs.n), - make_tuple(kargs.stride, 1), + make_tuple(kargs.x_stride, 1), number{}, number<1>{}); @@ -201,7 +210,7 @@ struct Layernorm2dFwd const auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_x_residual), make_tuple(kargs.m, kargs.n), - make_tuple(kargs.stride, 1), + make_tuple(kargs.xr_stride, 1), number{}, number<1>{}); @@ -250,7 +259,7 @@ struct Layernorm2dFwd auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_y), make_tuple(kargs.m, kargs.n), - make_tuple(kargs.stride, 1), + make_tuple(kargs.y_stride, 1), number{}, number<1>{}); @@ -266,7 +275,7 @@ struct Layernorm2dFwd auto tmp_ = make_naive_tensor_view( static_cast(kargs.p_y_residual), make_tuple(kargs.m, kargs.n), - make_tuple(kargs.stride, 1), + make_tuple(kargs.yr_stride, 1), number{}, number<1>{}); From 5fb150dbe700eba180feb5b27973a8ba95fae2ce Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Mon, 11 Nov 2024 09:25:08 -0800 Subject: [PATCH 14/24] restore collecting performance of mixed prec gemms (#1648) --- script/process_perf_data.py | 4 ++-- script/process_qa_data.sh | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/script/process_perf_data.py b/script/process_perf_data.py index b82a7c2891..3892206e42 100644 --- a/script/process_perf_data.py +++ b/script/process_perf_data.py @@ -133,12 +133,12 @@ def parse_logfile(logfile): if 'Best Perf' in line: lst=line.split() res.append(lst[4]) - elif 'onnx_gemm' in logfile or 'mixed_gemm' in logfile: + elif 'onnx_gemm' in logfile: for line in open(logfile): if 'Best Perf' in line: lst=line.split() res.append(lst[33]) - elif 'splitK_gemm' in logfile: + elif 'splitK_gemm' in logfile or 'mixed_gemm' in logfile: for line in open(logfile): if 'Best Perf' in line: lst=line.split() diff --git a/script/process_qa_data.sh b/script/process_qa_data.sh index d6083d2fc7..c9a1645f6e 100755 --- a/script/process_qa_data.sh +++ b/script/process_qa_data.sh @@ -22,6 +22,7 @@ python3 process_perf_data.py perf_gemm_bilinear.log python3 process_perf_data.py perf_reduction.log python3 process_perf_data.py perf_splitK_gemm.log python3 process_perf_data.py perf_onnx_gemm.log +python3 process_perf_data.py perf_mixed_gemm.log file=./perf_fmha_fwd_gfx942.log if [ -e "$file" ]; then From 2b6458ddf243904cecf4c54b48c9dafa60ff80df Mon Sep 17 00:00:00 2001 From: Thomas Ning Date: Tue, 12 Nov 2024 10:08:25 +0800 Subject: [PATCH 15/24] [CK Tile] Improve the Layout, Padding, and Alignment features of CK Tile GEMM (#1651) * Finished the feature * Modified the test file * Test case update * addresss comment * Addressed the review comment * Fixed the CI error --- example/ck_tile/03_gemm/README.md | 3 + example/ck_tile/03_gemm/gemm_basic.cpp | 19 +- example/ck_tile/03_gemm/gemm_mem_pipeline.cpp | 10 +- include/ck_tile/core/tensor/shuffle_tile.hpp | 2 +- ...k_fmha_pipeline_qx_ks_vs_custom_policy.hpp | 2 + .../ck_tile/ops/gemm/kernel/gemm_kernel.hpp | 70 ++-- .../pipeline/gemm_pipeline_ag_bg_cr_mem.hpp | 6 +- .../gemm_pipeline_agmem_bgmem_creg_v1.hpp | 63 ++-- ...ine_agmem_bgmem_creg_v1_default_policy.hpp | 318 ++++++++++++++---- .../gemm/pipeline/gemm_pipeline_problem.hpp | 154 ++++++--- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 312 ++++++++++++++--- .../ops/gemm/pipeline/tile_gemm_traits.hpp | 16 +- .../gemm/test_gemm_mem_pipeline_util.hpp | 12 +- 13 files changed, 773 insertions(+), 214 deletions(-) diff --git a/example/ck_tile/03_gemm/README.md b/example/ck_tile/03_gemm/README.md index aacbdf6863..e9ffe72a91 100644 --- a/example/ck_tile/03_gemm/README.md +++ b/example/ck_tile/03_gemm/README.md @@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat mkdir build && cd build # you can replace with the appropriate architecture (for example gfx90a or gfx942) or leave it blank sh ../script/cmake-ck-dev.sh ../ +# The basic pipeline method on the gemm calculation make tile_example_gemm_basic -j +# The memory bound pipeline on the gemm calculation +make tile_example_gemm_mem_pipeline -j ``` This will result in an executable `build/bin/tile_example_gemm_basic` diff --git a/example/ck_tile/03_gemm/gemm_basic.cpp b/example/ck_tile/03_gemm/gemm_basic.cpp index 09427217c5..b7d8693442 100644 --- a/example/ck_tile/03_gemm/gemm_basic.cpp +++ b/example/ck_tile/03_gemm/gemm_basic.cpp @@ -17,10 +17,11 @@ template float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) { - // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = true; + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + constexpr bool kTilePermute = false; // The rank and permutation will also be generate out by the CodeGen part. constexpr ck_tile::index_t kOutputRank = 2; @@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) CShuffleEpilogue, ck_tile::CShuffleEpilogue>, ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>>; + ck_tile::Default2DEpilogueProblem>>; using CodegenGemmTraits = - ck_tile::TileGemmTraits; + ck_tile::TileGemmTraits; using CodegenPipelineProblem = ck_tile:: GemmPipelineProblem; - using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; + using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy; using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; // ToDo: Will add the codegen part to test different pipeline policies in GEMM. diff --git a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp index 2ee0395e47..ff9d8bad32 100644 --- a/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp +++ b/example/ck_tile/03_gemm/gemm_mem_pipeline.cpp @@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) constexpr ck_tile::index_t K_Warp_Tile = 8; // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = true; + constexpr bool kPadM = true; + constexpr bool kPadN = true; + constexpr bool kPadK = true; constexpr int kBlockPerCu = 1; @@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) using TilePartitioner = ck_tile::GemmTilePartitioner; using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + ck_tile::Default2DEpilogueProblem>; - using Traits = ck_tile::TileGemmTraits; + using Traits = ck_tile::TileGemmTraits; using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< ck_tile::GemmPipelineProblem>; diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index da3c7117e5..55e3274cde 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in) } else { - // NOT implemented + static_assert(false, "The shuffle should always happen!"); } } diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index fbb05e1641..a3a29bb540 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy{}, number{}), - // somehow clang-format is splitting below line into multiple. - // clang-format off - sequence{}); + auto a_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + a_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + }(); // clang-format on auto a_block_window = make_tile_window( @@ -128,12 +138,22 @@ struct GemmKernel make_tuple(number{}, number{}), {i_m, 0}); - auto b_pad_view = pad_tensor_view( - b_tensor_view, - make_tuple(number{}, number{}), - // clang-format off - sequence{}); - // clang-format on + auto b_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + b_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + }(); auto b_block_window = make_tile_window( b_pad_view, @@ -171,18 +191,28 @@ struct GemmKernel } }(); - auto c_pad_view = pad_tensor_view( - c_tensor_view, - make_tuple(number{}, number{}), - // clang-format off - sequence{}); - // clang-format on - auto c_block_window = make_tile_window( + auto c_pad_view = [&]() { + if constexpr(std::is_same_v) + { + return pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + else + { + return pad_tensor_view( + c_tensor_view, + make_tuple(number{}, number{}), + sequence{}); + } + }(); + auto CBlockWindow_pad = make_tile_window( c_pad_view, make_tuple(number{}, number{}), {i_m, i_n}); - EpiloguePipeline{}(c_block_window, c_block_tile); + EpiloguePipeline{}(CBlockWindow_pad, c_block_tile); } }; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp index b9b45d3f42..85c5c58056 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp @@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr bool kPadA = Problem::kPadA; - static constexpr bool kPadB = Problem::kPadB; - static constexpr bool kPadC = Problem::kPadC; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; // Where is the right place for HasHotLoop and TailNum ??? static constexpr bool HasHotLoop = Problem::HasHotLoop; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp index a2424290e6..c0817e736b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp @@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1 static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeC = Problem::VectorSizeC; - static constexpr bool kPadA = Problem::kPadA; - static constexpr bool kPadB = Problem::kPadB; - static constexpr bool kPadC = Problem::kPadC; + static constexpr bool kPadM = Problem::kPadM; + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kPadK = Problem::kPadK; CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() { @@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1 Policy::template MakeADramTileDistribution()); // A LDS tile window for store - auto a_copy_lds_window = - make_tile_window(a_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - a_copy_dram_window.get_tile_distribution()); + auto a_copy_lds_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); // B DRAM tile window for load auto b_copy_dram_window = @@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1 Policy::template MakeBDramTileDistribution()); // B LDS tile window for store - auto b_copy_lds_window = - make_tile_window(b_lds_block, - make_tuple(number{}, number{}), - {0, 0}, - b_copy_dram_window.get_tile_distribution()); + auto b_copy_lds_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); // A LDS tile for block GEMM auto a_lds_gemm_window = make_tile_window( @@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); // LDS write 0 - const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); - store_tile(a_copy_lds_window, a_block_tile_tmp); + if constexpr(std::is_same_v) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegBlockDescriptor()); + shuffle_tile(a_shuffle_tmp, a_block_tile); + const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); + store_tile(a_copy_lds_window, a_block_tile_tmp); + } + else + { + store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile)); + } // LDS write 0 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + if constexpr(std::is_same_v) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegBlockDescriptor()); + shuffle_tile(b_shuffle_tmp, b_block_tile); + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } + else + { + store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile)); + } } index_t iCounter = num_loop - 1; @@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1 store_tile(a_copy_lds_window, a_block_tile_tmp); // LDS write i + 1 - const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); - store_tile(b_copy_lds_window, b_block_tile_tmp); + if constexpr(std::is_same_v) + { + auto b_shuffle_tmp_loop = make_static_distributed_tensor( + Policy::template MakeShuffledBRegBlockDescriptor()); + shuffle_tile(b_shuffle_tmp_loop, b_block_tile); + store_tile(b_copy_lds_window, + tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); + } + else + { + const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); + store_tile(b_copy_lds_window, b_block_tile_tmp); + } iCounter--; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp index 199ba56aac..c765b3ce9d 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp @@ -11,6 +11,7 @@ namespace ck_tile { // Default policy class should not be templated, put template on member functions instead struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy { + #if 0 // 2d template @@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy return smem_size; } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + using ADataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(ADataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + { + using BDataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(BDataType); + } #elif 1 // fake XOR template @@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; - constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; - constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = 16 / sizeof(ADataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t M1 = kBlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t M0 = kMPerBlock / (M2 * M1); + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * M0)) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M1 = kMPerBlock / (M2 * M0); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); -#endif + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } } template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + constexpr index_t BlockSize = Problem::kBlockSize; + + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + if constexpr(std::is_same_v) + { + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t KPack = GetSmemPackB(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (N2 * K0) == 0) + { + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + // coalesce reading for each warps + else + { + constexpr index_t N0 = BlockSize / get_warp_size(); + constexpr index_t N1 = NPerBlock / (N2 * N0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() + { + using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + static_assert(std::is_same_v); constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = 16 / sizeof(BDataType); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; -#if 1 // coalesce reading for each blocks - constexpr index_t N1 = kBlockSize / get_warp_size(); - static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = kNPerBlock / (N2 * N1); + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = kNPerBlock / N1; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemPackB(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * N0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * N0); + constexpr index_t K0 = kBlockSize / warp_size; - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); -#else // coalesce reading for each warps - constexpr index_t N0 = kBlockSize / get_warp_size(); - constexpr index_t N1 = kNPerBlock / (N2 * N0); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<1, 1>>{}); -#endif + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = kMPerBlock / M1; + constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackA(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * M0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = kBlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = kBlockSize / get_warp_size() / K1; + static_assert(kKPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } } template diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp index 1156f549b6..3c43790bd6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp @@ -3,40 +3,133 @@ #pragma once -#include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" namespace ck_tile { -static constexpr int _VectorSize = 16; - template -struct GemmPipelineProblem +struct GemmPipelineProblemBase { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; + using GemmTraits = remove_cvref_t; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; - using GemmTraits = remove_cvref_t; using ALayout = remove_cvref_t; using BLayout = remove_cvref_t; using CLayout = remove_cvref_t; - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr bool kPadA = GemmTraits::kPadA; - static constexpr bool kPadB = GemmTraits::kPadB; - static constexpr bool kPadC = GemmTraits::kPadC; + static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize; + static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - static constexpr index_t VectorSizeA = kPadA ? 1 : _VectorSize / sizeof(ADataType); - static constexpr index_t VectorSizeB = kPadB ? 1 : _VectorSize / sizeof(BDataType); - static constexpr index_t VectorSizeC = kPadC ? 1 : _VectorSize / sizeof(CDataType); + static constexpr bool kPadM = GemmTraits::kPadM; + static constexpr bool kPadN = GemmTraits::kPadN; + static constexpr bool kPadK = GemmTraits::kPadK; + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() + { + if constexpr(std::is_same_v) + { + constexpr index_t pixels_per_thread = + BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize; + return pixels_per_thread < VectorLoadSize / sizeof(ADataType) + ? pixels_per_thread + : VectorLoadSize / sizeof(ADataType); + } + else + { + return VectorLoadSize / sizeof(ADataType); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB() + { + if constexpr(std::is_same_v) + { + constexpr index_t pixels_per_thread = + BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize; + return pixels_per_thread < VectorLoadSize / sizeof(BDataType) + ? pixels_per_thread + : VectorLoadSize / sizeof(BDataType); + } + else + { + return VectorLoadSize / sizeof(BDataType); + } + } + + CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC() + { + if constexpr(std::is_same_v) + { + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size()); + constexpr index_t M0 = get_warp_size() / N2; + constexpr index_t M1 = BlockGemmShape::kM / M0; + + return std::min(M1, static_cast(VectorLoadSize / sizeof(CDataType))); + } + else + { + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size()); + constexpr index_t N0 = get_warp_size() / M2; + constexpr index_t N1 = BlockGemmShape::kN / N0; + + return std::min(N1, static_cast(VectorLoadSize / sizeof(CDataType))); + } + } + + static constexpr index_t VectorSizeA = []() { + if constexpr(std::is_same_v) + { + return kPadK ? 1 : GetAlignmentA(); + } + else + { + return kPadM ? 1 : GetAlignmentA(); + } + }(); + + static constexpr index_t VectorSizeB = []() { + if constexpr(std::is_same_v) + { + return kPadN ? 1 : GetAlignmentB(); + } + else + { + return kPadK ? 1 : GetAlignmentB(); + } + }(); + + static constexpr index_t VectorSizeC = []() { + if constexpr(std::is_same_v) + { + return kPadN ? 1 : GetAlignmentC(); + } + else + { + return kPadM ? 1 : GetAlignmentC(); + } + }(); }; +// Alias for GemmPipelineProblem +template +using GemmPipelineProblem = + GemmPipelineProblemBase; + template -struct UniversalGemmPipelineProblem +struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase { - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; - using GemmTraits = remove_cvref_t; - - using ALayout = remove_cvref_t; - using BLayout = remove_cvref_t; - using CLayout = remove_cvref_t; - - static constexpr auto Scheduler = Scheduler_; - static constexpr auto HasHotLoop = HasHotLoop_; - static constexpr auto TailNum = TailNum_; - static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); - - static constexpr bool kPadA = GemmTraits::kPadA; - static constexpr bool kPadB = GemmTraits::kPadB; - static constexpr bool kPadC = GemmTraits::kPadC; - - static constexpr index_t VectorSizeA = kPadA ? _VectorSize / sizeof(ADataType) : 1; - static constexpr index_t VectorSizeB = kPadB ? _VectorSize / sizeof(BDataType) : 1; - static constexpr index_t VectorSizeC = kPadC ? _VectorSize / sizeof(CDataType) : 1; + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; }; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 7044a53140..207f1f9e4b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -9,12 +9,8 @@ namespace ck_tile { // UniversalGemm Policy -template struct UniversalGemmPipelineAgBgCrPolicy { - using LayoutA = remove_cvref_t; - using LayoutB = remove_cvref_t; - using LayoutC = remove_cvref_t; static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; @@ -34,13 +30,14 @@ struct UniversalGemmPipelineAgBgCrPolicy TransposeC>; using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t K1 = WarpGemm::kK; constexpr index_t K0 = KPerBlock / K1; - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 ? 1 @@ -176,13 +173,15 @@ struct UniversalGemmPipelineAgBgCrPolicy using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t K1 = WarpGemm::kK; constexpr index_t K0 = KPerBlock / K1; - if constexpr(std::is_same::value) + if constexpr(std::is_same::value) { // NLdsLayer * K0 as logical Bank constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 @@ -331,72 +330,285 @@ struct UniversalGemmPipelineAgBgCrPolicy return smem_size; } + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() + { + using ADataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(ADataType); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() + { + using BDataType = remove_cvref_t; + return Problem::VectorLoadSize / sizeof(BDataType); + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - using WarpGemm = WarpGemmMfmaDispatcher; + using ADataType = remove_cvref_t; + using ALayout = remove_cvref_t; constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = WarpGemm::kK; - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t M2 = get_warp_size() / K0; - - constexpr index_t M1 = BlockSize / get_warp_size(); - static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t M0 = MPerBlock / (M2 * M1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + if constexpr(std::is_same_v) + { + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t KPack = GetSmemPackA(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * M0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * M0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + if constexpr(get_warp_size() % (M2 * K0) == 0) + { + constexpr index_t M1 = BlockSize / get_warp_size(); + static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); + static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); + constexpr index_t M0 = MPerBlock / (M2 * M1); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + else + { + constexpr index_t M0 = BlockSize / get_warp_size(); + constexpr index_t M1 = MPerBlock / (M2 * M0); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } } template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { - using WarpGemm = WarpGemmMfmaDispatcher; + using BDataType = remove_cvref_t; + using BLayout = remove_cvref_t; constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = WarpGemm::kK; - constexpr index_t K0 = KPerBlock / K1; - constexpr index_t N2 = get_warp_size() / K0; + if constexpr(std::is_same_v) + { + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t KPack = GetSmemPackB(); + static_assert(KPack % K3 == 0); + constexpr index_t K2 = KPack / K3; + if constexpr(get_warp_size() % (K2 * N0) == 0) + { + constexpr index_t K1 = get_warp_size() / (K2 * N0); + constexpr index_t K0 = BlockSize / get_warp_size(); + static_assert(KPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<2, 1>, + sequence<3, 1>>{}); + } + } + else + { - constexpr index_t N1 = BlockSize / get_warp_size(); - static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); - static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); - constexpr index_t N0 = NPerBlock / (N2 * N1); + constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t K0 = KPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + if constexpr(get_warp_size() % (N2 * K0) == 0) + { + constexpr index_t N1 = BlockSize / get_warp_size(); + static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error."); + static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error."); + constexpr index_t N0 = NPerBlock / (N2 * N1); - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, - sequence<1, 2>, - sequence<0, 1>>{}); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + // coalesce reading for each warps + else + { + constexpr index_t N0 = BlockSize / get_warp_size(); + constexpr index_t N1 = NPerBlock / (N2 * N0); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<1, 1>>{}); + } + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() + { + using ALayout = remove_cvref_t; + using ADataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); + constexpr index_t M0 = MPerBlock / M1; + constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % M1 == 0); + constexpr index_t K3 = total_pixels / M1; + constexpr index_t kKPack = GetSmemPackB(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * M0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * M0); + constexpr index_t K0 = BlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * M0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() + { + using BLayout = remove_cvref_t; + using BDataType = remove_cvref_t; + static_assert(std::is_same_v); + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); + constexpr index_t N0 = NPerBlock / N1; + constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; + static_assert(total_pixels % N1 == 0); + constexpr index_t K3 = total_pixels / N1; + constexpr index_t kKPack = GetSmemPackB(); + static_assert(kKPack % K3 == 0); + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t warp_size = get_warp_size(); + if constexpr(warp_size % (K2 * N0) == 0) + { + constexpr index_t K1 = warp_size / (K2 * N0); + constexpr index_t K0 = BlockSize / warp_size; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } + else + { + constexpr index_t K1 = (K2 * N0) / get_warp_size(); + constexpr index_t K2_m = K2 / K1; + constexpr index_t K0 = BlockSize / get_warp_size() / K1; + static_assert(KPerBlock == K0 * K1 * K2_m * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<0, 2>>, + sequence<1, 2>, + sequence<1, 3>>{}); + } } template diff --git a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp index 9d050be2fb..34756c3ff6 100644 --- a/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp +++ b/include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp @@ -3,19 +3,23 @@ #pragma once +#include "ck_tile/core.hpp" + namespace ck_tile { -template struct TileGemmTraits { - static constexpr bool kPadA = kPadA_; - static constexpr bool kPadB = kPadB_; - static constexpr bool kPadC = kPadC_; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool kPadK = kPadK_; + + static constexpr int _VectorSize = 16; using ALayout = ALayout_; using BLayout = BLayout_; diff --git a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp index 1b243ab437..6b47898339 100644 --- a/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_mem_pipeline_util.hpp @@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t K_Warp_Tile = 8; - constexpr bool kPadA = true; - constexpr bool kPadB = true; - constexpr bool kPadC = true; + constexpr bool kPadM = true; + constexpr bool kPadN = true; + constexpr bool kPadK = true; constexpr int kBlockPerCu = 1; @@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test using TilePartitioner = ck_tile::GemmTilePartitioner; using GemmEpilogue = ck_tile::Default2DEpilogue< - ck_tile::Default2DEpilogueProblem>; + ck_tile::Default2DEpilogueProblem>; - using Traits = ck_tile::TileGemmTraits; + using Traits = ck_tile::TileGemmTraits; using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< ck_tile::GemmPipelineProblem>; @@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test if(s.log_level_ > 0) { - std::cout << "Lunching kernel with args:" + std::cout << "Launching kernel with args:" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl; From 489c78d0735b7817859a22722e381f62f345cea7 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Tue, 12 Nov 2024 09:35:33 -0800 Subject: [PATCH 16/24] test rocm6.3 rc1 build 20 (#1659) --- Dockerfile | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index e2e2bc276f..791d1d9f3a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -24,10 +24,10 @@ RUN if [ "$ROCMVERSION" != "6.3" ]; then \ sh -c "echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] $DEB_ROCM_REPO focal main > /etc/apt/sources.list.d/rocm.list" && \ sh -c 'echo deb [arch=amd64 signed-by=/etc/apt/trusted.gpg.d/rocm-keyring.gpg] https://repo.radeon.com/amdgpu/$ROCMVERSION/ubuntu focal main > /etc/apt/sources.list.d/amdgpu.list'; \ elif [ "$ROCMVERSION" = "6.3" ] && [ "$compiler_version" = "rc1" ]; then \ - sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3.0.1-20.04-1_all.deb --no-check-certificate" && \ - apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3.0.1-20.04-1_all.deb && \ - sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3.0.1 rel-5 > /etc/apt/sources.list.d/rocm-build.list' && \ - amdgpu-repo --amdgpu-build=2033700; \ + sh -c "wget http://artifactory-cdn.amd.com/artifactory/list/amdgpu-deb/amdgpu-install-internal_6.3-20.04-1_all.deb --no-check-certificate" && \ + apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install dialog libpopt0 rsync && DEBIAN_FRONTEND=noninteractive apt-get install ./amdgpu-install-internal_6.3-20.04-1_all.deb && \ + sh -c 'echo deb [arch=amd64 trusted=yes] http://compute-artifactory.amd.com/artifactory/list/rocm-release-archive-20.04-deb/ 6.3 rel-20 > /etc/apt/sources.list.d/rocm-build.list' && \ + amdgpu-repo --amdgpu-build=2074281; \ fi RUN sh -c "echo deb http://mirrors.kernel.org/ubuntu focal main universe | tee -a /etc/apt/sources.list" From d20735691ccb9429ed66f42f831385c709707d62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bart=C5=82omiej=20Kocot?= Date: Wed, 13 Nov 2024 11:46:18 +0100 Subject: [PATCH 17/24] [CK TILE] Update gemm universal pipeline (#1644) * [CK TILE] Update gemm universal pipeline * Fixes * fix * Rebase --- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 377 +++++------------- 1 file changed, 105 insertions(+), 272 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 207f1f9e4b..94b0faf039 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -18,289 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy static constexpr bool TransposeC = true; + template + CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize() + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; + + if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0) + { + return (16 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0) + { + return (8 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 && + sizeof(DataType) >= 4) + { + return (4 / sizeof(DataType)); + } + else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 && + sizeof(DataType) >= 2) + { + return (2 / sizeof(DataType)); + } + else + { + return 1; + } + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - using WarpGemm = WarpGemmMfmaDispatcher; using ADataType = remove_cvref_t; - using ALayout = remove_cvref_t; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t K1 = WarpGemm::kK; - constexpr index_t K0 = KPerBlock / K1; + constexpr index_t KPack = GetVectorLoadSize(); - if constexpr(std::is_same::value) - { - constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(ADataType); - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor( - make_tuple(K0 * number{}, number{}, K1), - make_tuple(K1, number{}, I1)); + constexpr auto DataTypeSize = sizeof(ADataType); + constexpr auto MLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(K1)), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(K0, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( - a_lds_block_desc_ak0_kMLdsLayer_m_ak1, - make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); + constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + a_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - return a_lds_block_desc_m_k; - } - else // ColumnMajor A - { - // kfold and mpair dimension is not always required. - // more dimension in merge_transform increase the difficulty of generating immarg offset - // for compiler. - constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0); - constexpr auto M1 = MPerBlock / M0; + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - constexpr auto KThreadWrite = Problem::kBlockSize / M0; - constexpr auto K0PerThreadWrite = K0 / KThreadWrite; - constexpr auto KThreadRead = 64 / WarpGemm::kM; - constexpr auto K0PerThreadRead = K0 / KThreadRead; - - constexpr auto kfold = - (K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=mpair<=kN0 - constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128) - ? 1 - : ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0 - ? M0 - : 128 / (K1 * WarpGemm::kM * sizeof(ADataType))); - - constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - K1)); - - constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( - a_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor( - a_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor( - a_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - K1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return a_lds_block_desc_m_k; - } + return a_lds_block_desc; } template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using WarpGemm = WarpGemmMfmaDispatcher; using BDataType = remove_cvref_t; - using BLayout = remove_cvref_t; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPack = GetVectorLoadSize(); - constexpr index_t K1 = WarpGemm::kK; - constexpr index_t K0 = KPerBlock / K1; + constexpr auto DataTypeSize = sizeof(BDataType); + constexpr auto NLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - if constexpr(std::is_same::value) - { - // NLdsLayer * K0 as logical Bank - constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 - ? 1 - : 32 * 4 / KPerBlock / sizeof(BDataType); - ; - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor( - make_tuple(K0 * number{}, number{}, K1), - make_tuple(K1, number{}, I1)); + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(K1)), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple(make_unmerge_transform(make_tuple(K0, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + b_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( - b_lds_block_desc_bk0_kNLdsLayer_n_bk1, - make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<0, 3>{}, sequence<1, 2>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return b_lds_block_desc_n_k; - } - else // RowMajor B - { - constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1); - constexpr auto N1 = NPerBlock / N0; - - constexpr auto KThreadWrite = Problem::kBlockSize / N0; - constexpr auto K0PerThreadWrite = K0 / KThreadWrite; - constexpr auto KThreadRead = 64 / WarpGemm::kN; - constexpr auto K0PerThreadRead = K0 / KThreadRead; - - constexpr auto kfold = - (K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType)); - constexpr auto KThreadReadPerm = - (kfold * K0PerThreadWrite / K0PerThreadRead) > 1 - ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead) - : KThreadRead; - - // 1<=npair<=kN0 - constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128) - ? 1 - : ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0 - ? N0 - : 128 / (K1 * WarpGemm::kN * sizeof(BDataType))); - - constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed( - make_tuple(number{}, - number{}, - number{}, - number{}, - number{}, - K1)); - - constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( - b_lds_block_desc, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_xor_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}), - make_tuple( - sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{})); - - constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor( - b_lds_block_desc_permuted, - make_tuple( - make_pass_through_transform(number{}), - make_pass_through_transform(number{}), - make_unmerge_transform(make_tuple(number{}, number{})), - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(K1)), - make_tuple(sequence<0>{}, - sequence<1>{}, - sequence<2>{}, - sequence<3>{}, - sequence<4>{}, - sequence<5>{}), - make_tuple(sequence<1>{}, - sequence<2>{}, - sequence<0, 3>{}, - sequence<4, 5>{}, - sequence<6>{}, - sequence<7>{})); - - constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor( - b_lds_block_desc_unmerged, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, - number{}, - number{}, - number{}, - K1)), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{}))), - make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}), - make_tuple(sequence<1>{}, sequence<0>{})); - - return b_lds_block_desc_n_k; - } + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + return b_lds_block_desc; } template @@ -330,20 +177,6 @@ struct UniversalGemmPipelineAgBgCrPolicy return smem_size; } - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() - { - using ADataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(ADataType); - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() - { - using BDataType = remove_cvref_t; - return Problem::VectorLoadSize / sizeof(BDataType); - } - template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { @@ -362,7 +195,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % M1 == 0); constexpr index_t K3 = total_pixels / M1; - constexpr index_t KPack = GetSmemPackA(); + constexpr index_t KPack = GetVectorLoadSize(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; if constexpr(get_warp_size() % (K2 * M0) == 0) @@ -445,7 +278,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % N1 == 0); constexpr index_t K3 = total_pixels / N1; - constexpr index_t KPack = GetSmemPackB(); + constexpr index_t KPack = GetVectorLoadSize(); static_assert(KPack % K3 == 0); constexpr index_t K2 = KPack / K3; if constexpr(get_warp_size() % (K2 * N0) == 0) @@ -530,7 +363,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % M1 == 0); constexpr index_t K3 = total_pixels / M1; - constexpr index_t kKPack = GetSmemPackB(); + constexpr index_t kKPack = GetVectorLoadSize(); static_assert(kKPack % K3 == 0); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave constexpr index_t warp_size = get_warp_size(); @@ -578,7 +411,7 @@ struct UniversalGemmPipelineAgBgCrPolicy constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; static_assert(total_pixels % N1 == 0); constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemPackB(); + constexpr index_t kKPack = GetVectorLoadSize(); static_assert(kKPack % K3 == 0); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave constexpr index_t warp_size = get_warp_size(); From 73f02a108347d626ee9b31789f0ff8b26ef87006 Mon Sep 17 00:00:00 2001 From: Taylor Ding Date: Wed, 13 Nov 2024 11:20:38 -0500 Subject: [PATCH 18/24] Move checks for compatibility from Argument() to IsSupportedArgument() (#1653) --- ..._grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 6bb5d431c9..17b7d962db 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -381,10 +381,6 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle { tildes = {i_ztilde, i_ytilde, i_xtilde}; } - else - { - throw std::runtime_error("wrong! only implemented for 2D and 3D now"); - } const auto a_grid_desc_ak0_m_ak1 = transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1( @@ -749,6 +745,12 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle return false; } } + + // check number of dimension, only implemented for 2D and 3D now + if(NDimSpatial != 2 && NDimSpatial != 3) + { + return false; + } return true; } From efd92615459c83d1af3f226f846b395323374a74 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Wed, 13 Nov 2024 09:20:18 -0800 Subject: [PATCH 19/24] fix clang format (#1662) --- .../device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp index 17b7d962db..3fb047f207 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp @@ -745,7 +745,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle return false; } } - + // check number of dimension, only implemented for 2D and 3D now if(NDimSpatial != 2 && NDimSpatial != 3) { From c1f8d53ce83c6ca6d15fec8d987974bc05008c16 Mon Sep 17 00:00:00 2001 From: feli Date: Thu, 14 Nov 2024 14:06:36 +0800 Subject: [PATCH 20/24] [Ck_tile] hot fix, fix rpcf param setting err (#1657) Co-authored-by: dummycoderfe --- .../pipeline/layernorm2d_fwd_pipeline_one_pass.hpp | 2 +- .../pipeline/layernorm2d_fwd_pipeline_two_pass.hpp | 14 +++++++++++--- .../ck_tile/ops/welford/block/block_welford.hpp | 13 +++++++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index 4b83ed4fbf..eefdaf9176 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass auto [mean, var] = block_welford(acc, cur_count, max_count); block_welford_sync(mean, var, cur_count); block_welford_cross_warp_sync(mean, var, cur_count, smem); - block_tile_welford_post_scale_var(var, cur_count); + block_tile_welford_post_scale_var(var, cur_count, constant{}); // compute inv-std auto inv_std = tile_elementwise_in( diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index fadf56dfd3..6a86cc43c9 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -35,6 +35,7 @@ struct Layernorm2dFwdPipelineTwoPass static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM static constexpr bool kPadN = Problem::Traits::kPadN; + static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv; static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; @@ -137,15 +138,22 @@ struct Layernorm2dFwdPipelineTwoPass block_welford_sync(mean, var, cur_count); block_welford_cross_warp_sync(mean, var, cur_count, smem); - block_tile_welford_post_scale_var(var, cur_count); + block_tile_welford_post_scale_var(var, cur_count, constant{}); // compute inv-std auto inv_std = tile_elementwise_in( [&](const auto& v_) { - return type_convert(1.0f) / (sqrt(v_ + epsilon)); + if(kFastFDiv && std::is_same_v) + { + return type_convert(1.0f) * + __builtin_amdgcn_rcpf(sqrt(v_ + epsilon)); + } + else + { + return type_convert(1.0f) / sqrt(v_ + epsilon); + } }, var); - if constexpr(kSaveMean) store_tile(mean_window, cast_tile(mean)); if constexpr(kSaveInvStd) diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/welford/block/block_welford.hpp index 968895e38e..56ca86d9df 100644 --- a/include/ck_tile/ops/welford/block/block_welford.hpp +++ b/include/ck_tile/ops/welford/block/block_welford.hpp @@ -47,8 +47,11 @@ struct BlockWelford auto x = ck_tile::type_convert(x_tensor[in_dstr_idx]); - welford_update( - mean_tensor(out_dstr_idx), var_tensor(out_dstr_idx), x, cur_count_); + welford_update(mean_tensor(out_dstr_idx), + var_tensor(out_dstr_idx), + x, + cur_count_, + constant{}); }); } }); @@ -159,7 +162,8 @@ struct BlockWelfordSync v_local_count, v_remote_mean, v_remote_var, - v_remote_count); + v_remote_count, + constant{}); }); } }); @@ -307,7 +311,8 @@ struct BlockWelfordCrossWarpSync v_local_count, v_remote_mean, v_remote_var, - v_remote_count); + v_remote_count, + constant{}); }); mean_tensor.get_thread_buffer()(i_0) = v_local_mean; From d805a461aae7454de448bc0305cce01192fbc198 Mon Sep 17 00:00:00 2001 From: Andriy Roshchenko <107577548+andriy-ca@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:40:50 -0700 Subject: [PATCH 21/24] Fix example_convnd_fwd_max_xdl_int8 failures on MI300 (#1666) * Improve test verbosity. * BUGFIX: Add missing initialization for reduction buffer * Change default initialization method Performance may be affected for fp32 and int8 examples. * Improve test verbosity * Cleanup --- .../common.hpp | 2 +- .../run_convnd_fwd_max_example.inc | 57 +++++++++++++------ .../gemm_add_add_mean_meansquare_xdl_fp16.cpp | 2 +- 3 files changed, 43 insertions(+), 18 deletions(-) diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp index 7e3130a1a1..036f288d0a 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/common.hpp @@ -80,7 +80,7 @@ using RLayout = typename LayoutSettingSelector::RLayout; struct ExecutionConfig final { bool do_verification = true; - int init_method = 1; + int init_method = 2; bool time_kernel = false; }; diff --git a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc index cebfeb51d6..d61aee81a4 100644 --- a/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc +++ b/example/10_convnd_fwd_multiple_d_multiple_reduce/run_convnd_fwd_max_example.inc @@ -73,16 +73,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, Tensor conv_output_device(conv_output_g_n_k_wos_desc); Tensor r0_device(r0_desc); + std::cout << "input: " << conv_input.mDesc << std::endl; + std::cout << "weight: " << conv_weight.mDesc << std::endl; + std::cout << "output: " << conv_output_device.mDesc << std::endl; + std::cout << "reduction: " << r0_device.mDesc << std::endl << std::endl; + switch(config.init_method) { case 0: break; case 1: ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input); - ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_weight); + ck::utils::FillUniformDistributionIntegerValue{-1, 1}(conv_weight); + break; + case 2: + ck::utils::FillUniformDistributionIntegerValue{-8, 7}(conv_input); + ck::utils::FillUniformDistribution{-1, 1}(conv_weight); break; default: - ck::utils::FillUniformDistribution{-5, 5}(conv_input); - ck::utils::FillUniformDistribution{-5, 5}(conv_weight); + ck::utils::FillUniformDistribution{-8, 7}(conv_input); + ck::utils::FillUniformDistribution{-1, 1}(conv_weight); } DeviceMem conv_input_device_buf(sizeof(ADataType) * conv_input.mDesc.GetElementSpaceSize()); @@ -161,15 +170,25 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, return false; } + // XXX: DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle will not initialize r0. + r0_device_buf.SetValue(ck::NumericLimits::Lowest()); + const float avg_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); - const std::size_t flop = problem_size.GetFlops(); - const std::size_t num_btype = problem_size.GetByte(); + if(config.time_kernel) + { + const std::size_t flop = problem_size.GetFlops(); + const std::size_t num_btype = problem_size.GetByte(); - const float tflops = static_cast(flop) / 1.E9 / avg_time; - const float gb_per_sec = num_btype / 1.E6 / avg_time; - std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << conv.GetTypeString() << std::endl; + const float tflops = static_cast(flop) / 1.E9 / avg_time; + const float gb_per_sec = num_btype / 1.E6 / avg_time; + std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << conv.GetTypeString() << std::endl; + } + else + { + std::cout << "FINISHED: " << conv.GetTypeString() << std::endl; + } if(config.do_verification) { @@ -189,6 +208,7 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, BElementOp{}, PassThrough{}); + std::cout << "\nRunning verification on CPU." << std::endl; ref_invoker.Run(ref_argument); Tensor r0_host(r0_device.mDesc); @@ -273,13 +293,18 @@ bool run_convnd_fwd_max(const ck::utils::conv::ConvParam& problem_size, conv_output_device_buf.FromDevice(conv_output_device.mData.data()); r0_device_buf.FromDevice(r0_device.mData.data()); - return ck::utils::check_err(conv_output_device, - conv_output_host, - "Error: incorrect results! (Matrix E)", - 1e-5f, - 1e-4f) && - ck::utils::check_err( - r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-5f, 1e-4f); + auto pass = ck::utils::check_err(conv_output_device, + conv_output_host, + "Error: incorrect results! (Matrix E)", + 1e-3f, + 1e-3f); + pass = + pass && ck::utils::check_err( + r0_device, r0_host, "Error: incorrect results! (Matrix R0)", 1e-3f, 1e-3f); + if(pass) + std::cout << "Verification on CPU: PASS" << std::endl; + + return pass; } return true; diff --git a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp index 2f6533d448..a46eaa4816 100644 --- a/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp +++ b/example/16_gemm_multi_d_multi_reduces/gemm_add_add_mean_meansquare_xdl_fp16.cpp @@ -198,7 +198,7 @@ int main() throw std::runtime_error("wrong! this device_op instance does not support this problem"); } - // init reducetion buffer to 0 + // init reduction buffer to 0 r0_device_buf.SetZero(); r1_device_buf.SetZero(); From 3b6a481e92d8ba2a9f9e87136678b05bcaf573a7 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:14:50 -0800 Subject: [PATCH 22/24] re-enable coerce-illegal-types flag for rocm6.3 (#1668) --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index bd2f606835..4bb69300a6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -221,7 +221,7 @@ if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600140090) endif() set(check-coerce) check_cxx_compiler_flag(" -mllvm -amdgpu-coerce-illegal-types=1" check-coerce) -if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132 AND ${hip_VERSION_FLAT} LESS 600300000) +if(NOT WIN32 AND check-coerce AND ${hip_VERSION_FLAT} GREATER 600241132) message("Adding the amdgpu-coerce-illegal-types=1") add_compile_options("SHELL: -mllvm -amdgpu-coerce-illegal-types=1") endif() From b4a79045829b07f7e80603fb773c196e1f7a7214 Mon Sep 17 00:00:00 2001 From: Illia Silin <98187287+illsilin@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:15:01 -0800 Subject: [PATCH 23/24] re-enable fp8 gemms in ckProfiler (#1667) --- CMakeLists.txt | 6 ++++-- profiler/src/profile_gemm_universal.cpp | 6 +++--- test/gemm_universal/test_gemm_universal_xdl.cpp | 4 ++-- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bb69300a6..b28a6d9127 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -183,12 +183,14 @@ message("Building CK for the following targets: ${SUPPORTED_GPU_TARGETS}") if (SUPPORTED_GPU_TARGETS MATCHES "gfx9") message("Enabling XDL instances") add_definitions(-DCK_USE_XDL) - set(CK_USE_XDL "ON") +endif() +if (SUPPORTED_GPU_TARGETS MATCHES "gfx94") + message("Enabling FP8 gemms in ckProfiler") + add_definitions(-DCK_USE_GFX94) endif() if (SUPPORTED_GPU_TARGETS MATCHES "gfx11" OR SUPPORTED_GPU_TARGETS MATCHES "gfx12") message("Enabling WMMA instances") add_definitions(-DCK_USE_WMMA) - set(CK_USE_WMMA "ON") endif() option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF) if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908")) diff --git a/profiler/src/profile_gemm_universal.cpp b/profiler/src/profile_gemm_universal.cpp index 576bd009b6..990cbd292e 100644 --- a/profiler/src/profile_gemm_universal.cpp +++ b/profiler/src/profile_gemm_universal.cpp @@ -101,7 +101,7 @@ int profile_gemm_universal(int argc, char* argv[]) using F32 = float; using F16 = ck::half_t; using BF16 = ck::bhalf_t; -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) using F8 = ck::f8_t; #endif @@ -164,7 +164,7 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(F16{}, F16{}, F16{}, F32{}, F16{}, Row{}, Col{}, Row{}); } -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) else if(data_type == GemmDataType::F16_F8_F16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F16{}, F8{}, F16{}, F32{}, F16{}, Row{}, Row{}, Row{}); @@ -198,7 +198,7 @@ int profile_gemm_universal(int argc, char* argv[]) { return profile(BF16{}, BF16{}, BF16{}, F32{}, BF16{}, Col{}, Row{}, Row{}); } -#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94) else if(data_type == GemmDataType::F8_F8_BF16 && layout == GemmMatrixLayout::MK_KN_MN) { return profile(F8{}, F8{}, F8{}, F32{}, BF16{}, Row{}, Row{}, Row{}); diff --git a/test/gemm_universal/test_gemm_universal_xdl.cpp b/test/gemm_universal/test_gemm_universal_xdl.cpp index 23b5c74ddd..b872d7089a 100644 --- a/test/gemm_universal/test_gemm_universal_xdl.cpp +++ b/test/gemm_universal/test_gemm_universal_xdl.cpp @@ -56,7 +56,7 @@ class TestGemmUniversal_KM_NK using KernelTypes_MK_KN = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) std::tuple< F16, F8, F16, F16>, std::tuple< F8, F16, F16, F16>, std::tuple< F8, F8, F8, BF16>, @@ -66,7 +66,7 @@ using KernelTypes_MK_KN = ::testing::Types< using KernelTypes_MK_NK = ::testing::Types< // ADataType, BDataType, ComputeDataType, CDataType std::tuple< F16, F16, F16, F16>, -#if defined(CK_ENABLE_FP8) && defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) +#if defined(CK_ENABLE_FP8) && (defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH) || defined(CK_USE_GFX94)) std::tuple< F16, F8, F16, F16>, std::tuple< F8, F16, F16, F16>, std::tuple< F8, F8, F8, BF16>, From efb34741fe1f6af938e32b80fa5a30211d8dd71c Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Fri, 15 Nov 2024 18:30:58 -0500 Subject: [PATCH 24/24] Bump rocm-docs-core from 1.8.3 to 1.8.4 in /docs/sphinx (#1670) Bumps [rocm-docs-core](https://github.com/ROCm/rocm-docs-core) from 1.8.3 to 1.8.4. - [Release notes](https://github.com/ROCm/rocm-docs-core/releases) - [Changelog](https://github.com/ROCm/rocm-docs-core/blob/v1.8.4/CHANGELOG.md) - [Commits](https://github.com/ROCm/rocm-docs-core/compare/v1.8.3...v1.8.4) --- updated-dependencies: - dependency-name: rocm-docs-core dependency-type: direct:production update-type: version-update:semver-patch ... Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- docs/sphinx/requirements.in | 2 +- docs/sphinx/requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/sphinx/requirements.in b/docs/sphinx/requirements.in index c2220e15db..9824df6266 100644 --- a/docs/sphinx/requirements.in +++ b/docs/sphinx/requirements.in @@ -1,2 +1,2 @@ -rocm-docs-core==1.8.3 +rocm-docs-core==1.8.4 sphinxcontrib-bibtex==2.6.3 diff --git a/docs/sphinx/requirements.txt b/docs/sphinx/requirements.txt index 0dc2e70c58..f89fbcf273 100644 --- a/docs/sphinx/requirements.txt +++ b/docs/sphinx/requirements.txt @@ -103,7 +103,7 @@ requests==2.32.3 # via # pygithub # sphinx -rocm-docs-core==1.8.3 +rocm-docs-core==1.8.4 # via -r requirements.in six==1.16.0 # via pybtex