mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
[GEMM] gemm_universal related optimization (#1453)
* replace buffer_atomic with global_atomic
* fixed global_atomic_add
* added bf16 atomic_add
* format
* clang-format-12
* clean
* clean
* add guards
* Update gtest.cmake
* enabled splitk_gemm_multi_d
* format
* add ckProfiler
* format
* fixed naming
* format
* clean
* clean
* add guards
* fix clang format
* format
* add kbatch printout
* clean
* Add rocm6.2 related gemm optimization
* Limit bf16 atomic usage
* remove redundant RCR gemm_universal instance
* Add RRR fp8 gemm universal instance
* Bug fix
* Add GPU_TARGET guard to FP8/BF8 target
* bug fix
* update cmake
* remove all fp8/bf8 example if arch not support
* Enable fp8 RRR support in ckProfiler
* limit greedy-reverse flag to gemm_universal in ckProfiler
---------
Co-authored-by: Jing Zhang <jizhan@fb.com>
Co-authored-by: Jing Zhang <jizhan@meta.com>
Co-authored-by: zjing14 <zhangjing14@gmail.com>
Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: illsilin <Illia.Silin@amd.com>
[ROCm/composable_kernel commit: 3049b5467c]
This commit is contained in:
@@ -5,17 +5,17 @@ if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
|
||||
target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
|
||||
if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
|
||||
add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp)
|
||||
target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
if((DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
|
||||
add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp)
|
||||
target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
|
||||
add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp)
|
||||
target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
|
||||
@@ -4,5 +4,7 @@ target_link_libraries(client_grouped_conv2d_bwd_data PRIVATE composable_kernel::
|
||||
add_executable(client_grouped_conv3d_bwd_data grouped_conv3d_bwd_data.cpp)
|
||||
target_link_libraries(client_grouped_conv3d_bwd_data PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
add_executable(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp)
|
||||
target_link_libraries(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 PRIVATE composable_kernel::device_conv_operations)
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
|
||||
add_executable(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 grouped_conv3d_bwd_data_input_fp16_comp_bf8f8.cpp)
|
||||
target_link_libraries(client_grouped_conv3d_bwd_data_input_fp16_comp_bf8f8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
@@ -2,10 +2,13 @@ add_executable(client_grouped_conv1d_bwd_weight_fp16 grouped_conv1d_bwd_weight_f
|
||||
add_executable(client_grouped_conv2d_bwd_weight_fp16 grouped_conv2d_bwd_weight_fp16.cpp)
|
||||
add_executable(client_grouped_conv3d_bwd_weight_fp16 grouped_conv3d_bwd_weight_fp16.cpp)
|
||||
add_executable(client_grouped_conv3d_bwd_weight_fp32 grouped_conv3d_bwd_weight_fp32.cpp)
|
||||
add_executable(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp)
|
||||
|
||||
target_link_libraries(client_grouped_conv1d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations)
|
||||
target_link_libraries(client_grouped_conv2d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations)
|
||||
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16 PRIVATE composable_kernel::device_conv_operations)
|
||||
target_link_libraries(client_grouped_conv3d_bwd_weight_fp32 PRIVATE composable_kernel::device_conv_operations)
|
||||
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
|
||||
add_executable(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8.cpp)
|
||||
target_link_libraries(client_grouped_conv3d_bwd_weight_fp16_comp_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
@@ -4,7 +4,7 @@ if((DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES)
|
||||
|
||||
endif()
|
||||
|
||||
if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
|
||||
if((DTYPES MATCHES "fp8") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
|
||||
add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
|
||||
target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations)
|
||||
endif()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9" AND ((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR NOT DEFINED DTYPES))
|
||||
if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "fp16") OR (NOT DEFINED DTYPES AND GPU_TARGETS MATCHES "gfx94"))
|
||||
add_executable(client_splitK_gemm splitK_gemm_fp16_f8.cpp)
|
||||
target_link_libraries(client_splitK_gemm PRIVATE composable_kernel::device_gemm_operations)
|
||||
endif()
|
||||
|
||||
@@ -34,8 +34,17 @@ if (DTYPES)
|
||||
endif()
|
||||
message("DTYPES macro set to ${DTYPES}")
|
||||
else()
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP8 -DCK_ENABLE_BF8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_ALL_DTYPES "ON")
|
||||
add_definitions(-DCK_ENABLE_INT8 -DCK_ENABLE_FP16 -DCK_ENABLE_FP32 -DCK_ENABLE_FP64 -DCK_ENABLE_BF16)
|
||||
set(CK_ENABLE_INT8 "ON")
|
||||
set(CK_ENABLE_FP16 "ON")
|
||||
set(CK_ENABLE_FP32 "ON")
|
||||
set(CK_ENABLE_FP64 "ON")
|
||||
set(CK_ENABLE_BF16 "ON")
|
||||
if (GPU_TARGETS MATCHES "gfx94")
|
||||
add_definitions(-DCK_ENABLE_FP8 -DCK_ENABLE_BF8)
|
||||
set(CK_ENABLE_FP8 "ON")
|
||||
set(CK_ENABLE_BF8 "ON")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GPU_TARGETS)
|
||||
|
||||
Reference in New Issue
Block a user